Skip to main content

webtrans_proto/
connect.rs

1//! Encode and decode WebTransport CONNECT requests and responses.
2
3use std::{str::FromStr, sync::Arc};
4
5use bytes::{Buf, BufMut, BytesMut};
6use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
7use url::Url;
8
9use super::{Frame, VarInt, qpack};
10use crate::io::read_incremental;
11
12use thiserror::Error;
13
14// Errors that can occur while processing a CONNECT request.
15#[derive(Error, Debug, Clone)]
16/// Errors returned while encoding or decoding WebTransport CONNECT messages.
17pub enum ConnectError {
18    #[error("unexpected end of input")]
19    /// Input ended before the complete frame/header block was available.
20    UnexpectedEnd,
21
22    #[error("qpack error")]
23    /// QPACK header block decoding failed.
24    QpackError(#[from] qpack::DecodeError),
25
26    #[error("unexpected frame {0:?}")]
27    /// Received an unexpected HTTP/3 frame type.
28    UnexpectedFrame(Frame),
29
30    #[error("invalid method")]
31    /// `:method` could not be parsed as an HTTP method.
32    InvalidMethod,
33
34    #[error("invalid url")]
35    /// URL reconstruction from pseudo-headers failed.
36    InvalidUrl(#[from] url::ParseError),
37
38    #[error("invalid status")]
39    /// `:status` could not be parsed as a valid HTTP status code.
40    InvalidStatus,
41
42    #[error("expected 200, got: {0:?}")]
43    /// Response status was not successful.
44    WrongStatus(Option<http::StatusCode>),
45
46    #[error("expected connect, got: {0:?}")]
47    /// Request method was not CONNECT.
48    WrongMethod(Option<http::method::Method>),
49
50    #[error("expected https, got: {0:?}")]
51    /// Request scheme was not HTTPS.
52    WrongScheme(Option<String>),
53
54    #[error("expected authority header")]
55    /// Required `:authority` pseudo-header was missing.
56    WrongAuthority,
57
58    #[error("expected webtransport, got: {0:?}")]
59    /// CONNECT protocol was not `webtransport`.
60    WrongProtocol(Option<String>),
61
62    #[error("expected path header")]
63    /// Required `:path` pseudo-header was missing.
64    WrongPath,
65
66    #[error("non-200 status: {0:?}")]
67    /// Peer returned an explicit non-2xx HTTP response status.
68    ErrorStatus(http::StatusCode),
69
70    #[error("io error: {0}")]
71    /// I/O error while reading or writing frames.
72    Io(Arc<std::io::Error>),
73}
74
75impl From<std::io::Error> for ConnectError {
76    fn from(err: std::io::Error) -> Self {
77        ConnectError::Io(Arc::new(err))
78    }
79}
80
81#[derive(Debug)]
82/// Decoded WebTransport CONNECT request metadata.
83pub struct ConnectRequest {
84    /// Target URL reconstructed from pseudo-headers.
85    pub url: Url,
86}
87
88impl ConnectRequest {
89    /// Decode a CONNECT request from an in-memory frame buffer.
90    pub fn decode<B: Buf>(buf: &mut B) -> Result<Self, ConnectError> {
91        let (typ, mut data) = Frame::read(buf).map_err(|_| ConnectError::UnexpectedEnd)?;
92        if typ != Frame::HEADERS {
93            return Err(ConnectError::UnexpectedFrame(typ));
94        }
95
96        // The frame payload should be complete, so an additional UnexpectedEnd is unexpected here.
97
98        let headers = qpack::Headers::decode(&mut data)?;
99
100        let scheme = match headers.get(":scheme") {
101            Some("https") => "https",
102            Some(scheme) => Err(ConnectError::WrongScheme(Some(scheme.to_string())))?,
103            None => return Err(ConnectError::WrongScheme(None)),
104        };
105
106        let authority = headers
107            .get(":authority")
108            .ok_or(ConnectError::WrongAuthority)?;
109
110        let path_and_query = headers.get(":path").ok_or(ConnectError::WrongPath)?;
111
112        let method = headers.get(":method");
113        match method
114            .map(|method| method.try_into().map_err(|_| ConnectError::InvalidMethod))
115            .transpose()?
116        {
117            Some(http::Method::CONNECT) => (),
118            o => return Err(ConnectError::WrongMethod(o)),
119        };
120
121        let protocol = headers.get(":protocol");
122        if protocol != Some("webtransport") {
123            return Err(ConnectError::WrongProtocol(protocol.map(|s| s.to_string())));
124        }
125
126        let url = Url::parse(&format!("{scheme}://{authority}{path_and_query}"))?;
127
128        Ok(Self { url })
129    }
130
131    /// Read and decode a CONNECT request from an async stream.
132    pub async fn read<S: AsyncRead + Unpin>(stream: &mut S) -> Result<Self, ConnectError> {
133        read_incremental(
134            stream,
135            |cursor| Self::decode(cursor),
136            |err| matches!(err, ConnectError::UnexpectedEnd),
137            ConnectError::UnexpectedEnd,
138        )
139        .await
140    }
141
142    /// Encode this CONNECT request into a HEADERS frame.
143    pub fn encode<B: BufMut>(&self, buf: &mut B) {
144        let mut headers = qpack::Headers::default();
145        headers.set(":method", "CONNECT");
146        headers.set(":scheme", self.url.scheme());
147        headers.set(":authority", self.url.authority());
148        let path_and_query = match self.url.query() {
149            Some(query) => format!("{}?{}", self.url.path(), query),
150            None => self.url.path().to_string(),
151        };
152        headers.set(":path", &path_and_query);
153        headers.set(":protocol", "webtransport");
154        encode_headers_frame(buf, &headers);
155    }
156
157    /// Encode and write this CONNECT request to an async stream.
158    pub async fn write<S: AsyncWrite + Unpin>(&self, stream: &mut S) -> Result<(), ConnectError> {
159        let mut buf = BytesMut::new();
160        self.encode(&mut buf);
161        stream.write_all_buf(&mut buf).await?;
162        Ok(())
163    }
164}
165
166#[derive(Debug)]
167/// Decoded WebTransport CONNECT response metadata.
168pub struct ConnectResponse {
169    /// HTTP status returned by the peer.
170    pub status: http::status::StatusCode,
171}
172
173impl ConnectResponse {
174    /// Decode a CONNECT response from an in-memory frame buffer.
175    pub fn decode<B: Buf>(buf: &mut B) -> Result<Self, ConnectError> {
176        let (typ, mut data) = Frame::read(buf).map_err(|_| ConnectError::UnexpectedEnd)?;
177        if typ != Frame::HEADERS {
178            return Err(ConnectError::UnexpectedFrame(typ));
179        }
180
181        let headers = qpack::Headers::decode(&mut data)?;
182
183        let status = match headers
184            .get(":status")
185            .map(|status| {
186                http::StatusCode::from_str(status).map_err(|_| ConnectError::InvalidStatus)
187            })
188            .transpose()?
189        {
190            Some(status) if status.is_success() => status,
191            o => return Err(ConnectError::WrongStatus(o)),
192        };
193
194        Ok(Self { status })
195    }
196
197    /// Read and decode a CONNECT response from an async stream.
198    pub async fn read<S: AsyncRead + Unpin>(stream: &mut S) -> Result<Self, ConnectError> {
199        read_incremental(
200            stream,
201            |cursor| Self::decode(cursor),
202            |err| matches!(err, ConnectError::UnexpectedEnd),
203            ConnectError::UnexpectedEnd,
204        )
205        .await
206    }
207
208    /// Encode this CONNECT response into a HEADERS frame.
209    pub fn encode<B: BufMut>(&self, buf: &mut B) {
210        let mut headers = qpack::Headers::default();
211        headers.set(":status", self.status.as_str());
212        headers.set("sec-webtransport-http3-draft", "draft02");
213        encode_headers_frame(buf, &headers);
214    }
215
216    /// Encode and write this CONNECT response to an async stream.
217    pub async fn write<S: AsyncWrite + Unpin>(&self, stream: &mut S) -> Result<(), ConnectError> {
218        let mut buf = BytesMut::new();
219        self.encode(&mut buf);
220        stream.write_all_buf(&mut buf).await?;
221        Ok(())
222    }
223}
224
225fn encode_headers_frame<B: BufMut>(buf: &mut B, headers: &qpack::Headers) {
226    // Encode headers into a temporary payload so we can prefix it with the frame length.
227    let mut payload = Vec::new();
228    headers.encode(&mut payload);
229
230    Frame::HEADERS.encode(buf);
231    VarInt::from_u32(payload.len() as u32).encode(buf);
232    buf.put_slice(&payload);
233}