ant_quic/masque/
connect.rs

// Copyright 2024 Saorsa Labs Ltd.
//
// This Saorsa Network Software is licensed under the General Public License (GPL), version 3.
// Please see the file LICENSE-GPL, or visit <http://www.gnu.org/licenses/> for the full text.
//
// Full details available at https://saorsalabs.com/licenses

//! HTTP CONNECT-UDP Bind Request/Response Types
//!
//! Implements the HTTP Extended CONNECT mechanism for establishing MASQUE relay
//! connections per RFC 9298 (CONNECT-UDP) and draft-ietf-masque-connect-udp-listen-10.
//!
//! # Protocol Overview
//!
//! CONNECT-UDP uses HTTP Extended CONNECT (defined for HTTP/2 in RFC 8441 and
//! for HTTP/3 in RFC 9220) over HTTP/3:
//!
//! ```text
//! Client                                          Relay
//!   |                                               |
//!   |  HEADERS (Extended CONNECT with :protocol)    |
//!   |---------------------------------------------->|
//!   |                                               |
//!   |  HEADERS (200 OK + Proxy-Public-Address)      |
//!   |<----------------------------------------------|
//!   |                                               |
//!   |  <-- Capsules and Datagrams flow -->          |
//! ```
//!
//! # CONNECT-UDP Bind Extension
//!
//! The bind extension allows requesting a public address for inbound connections:
//! - Target host `"::"` indicates bind-any (IPv4 and IPv6)
//! - Target port `0` lets the relay choose the port
//! - The relay responds with the public address it allocated
//!
//! # Example
//!
//! ```rust
//! use ant_quic::masque::connect::{ConnectUdpRequest, ConnectUdpResponse};
//! use std::net::{SocketAddr, IpAddr, Ipv4Addr};
//!
//! // Create a bind request
//! let request = ConnectUdpRequest::bind_any();
//! assert!(request.is_bind_request());
//!
//! // Create a targeted request
//! let target = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(192, 168, 1, 100)), 8080);
//! let request = ConnectUdpRequest::target(target);
//! assert!(!request.is_bind_request());
//!
//! // Construct a successful response
//! let response = ConnectUdpResponse::success(
//!     Some(SocketAddr::new(IpAddr::V4(Ipv4Addr::new(203, 0, 113, 50)), 9000))
//! );
//! assert!(response.is_success());
//! ```
57
58use bytes::{Buf, BufMut, Bytes, BytesMut};
59use std::fmt;
60use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr};
61use thiserror::Error;
62
63use crate::VarInt;
64use crate::coding::{self, Codec};
65
/// `:protocol` value sent with a standard (targeted) CONNECT-UDP request.
pub const CONNECT_UDP_PROTOCOL: &str = "connect-udp";

/// `:protocol` value sent with a CONNECT-UDP bind request.
pub const CONNECT_UDP_BIND_PROTOCOL: &str = "connect-udp-bind";

/// Wildcard host used by bind requests so that the relay picks the address.
pub const BIND_ANY_HOST: &str = "::";

/// Wildcard port used by bind requests so that the relay picks the port.
pub const BIND_ANY_PORT: u16 = 0;
74
75/// Bind-any port (indicates relay should choose)
76pub const BIND_ANY_PORT: u16 = 0;
77
78/// Errors that can occur during CONNECT-UDP processing
79#[derive(Debug, Error)]
80pub enum ConnectError {
81    /// Invalid request format
82    #[error("invalid request: {0}")]
83    InvalidRequest(String),
84
85    /// Invalid response format
86    #[error("invalid response: {0}")]
87    InvalidResponse(String),
88
89    /// Request was rejected by relay
90    #[error("rejected: status {status}, reason: {reason}")]
91    Rejected {
92        /// HTTP status code
93        status: u16,
94        /// Human-readable reason
95        reason: String,
96    },
97
98    /// Encoding/decoding error
99    #[error("codec error")]
100    Codec,
101
102    /// Connection failed
103    #[error("connection failed: {0}")]
104    ConnectionFailed(String),
105}
106
107/// HTTP CONNECT-UDP Request
108///
109/// Represents an Extended CONNECT request for establishing a UDP proxy session.
110/// Can be either a targeted request (proxy to specific destination) or a bind
111/// request (request public address for inbound connections).
112#[derive(Debug, Clone, PartialEq, Eq)]
113pub struct ConnectUdpRequest {
114    /// Target host ("::" for bind-any)
115    pub target_host: String,
116    /// Target port (0 for bind-any)
117    pub target_port: u16,
118    /// Whether this is a bind request (vs. targeted proxy)
119    pub connect_udp_bind: bool,
120}
121
122impl ConnectUdpRequest {
123    /// Create a bind-any request
124    ///
125    /// Requests the relay allocate a public address for receiving inbound
126    /// connections. The relay will choose both the IP and port.
127    pub fn bind_any() -> Self {
128        Self {
129            target_host: BIND_ANY_HOST.to_string(),
130            target_port: BIND_ANY_PORT,
131            connect_udp_bind: true,
132        }
133    }
134
135    /// Create a bind request for a specific port
136    ///
137    /// Requests the relay allocate a public address with a specific port.
138    /// The relay may reject this if the port is unavailable.
139    pub fn bind_port(port: u16) -> Self {
140        Self {
141            target_host: BIND_ANY_HOST.to_string(),
142            target_port: port,
143            connect_udp_bind: true,
144        }
145    }
146
147    /// Create a targeted proxy request
148    ///
149    /// Requests the relay forward UDP traffic to a specific destination.
150    /// This is the standard CONNECT-UDP mode (not bind).
151    pub fn target(addr: SocketAddr) -> Self {
152        Self {
153            target_host: addr.ip().to_string(),
154            target_port: addr.port(),
155            connect_udp_bind: false,
156        }
157    }
158
159    /// Check if this is a bind request
160    pub fn is_bind_request(&self) -> bool {
161        self.connect_udp_bind
162    }
163
164    /// Check if this is a bind-any request (both host and port unspecified)
165    pub fn is_bind_any(&self) -> bool {
166        self.connect_udp_bind
167            && (self.target_host == BIND_ANY_HOST || self.target_host == "0.0.0.0")
168            && self.target_port == BIND_ANY_PORT
169    }
170
171    /// Get the target socket address if this is a targeted request
172    pub fn target_addr(&self) -> Option<SocketAddr> {
173        if self.is_bind_request() {
174            return None;
175        }
176
177        let ip: IpAddr = self.target_host.parse().ok()?;
178        Some(SocketAddr::new(ip, self.target_port))
179    }
180
181    /// Get the protocol string for HTTP headers
182    pub fn protocol(&self) -> &'static str {
183        if self.connect_udp_bind {
184            CONNECT_UDP_BIND_PROTOCOL
185        } else {
186            CONNECT_UDP_PROTOCOL
187        }
188    }
189
190    /// Encode the request as a wire format message
191    ///
192    /// Format: `[flags (1)] [host_len (varint)] [host] [port (2)]`
193    pub fn encode(&self) -> Bytes {
194        let mut buf = BytesMut::new();
195
196        // Flags byte: bit 0 = connect_udp_bind
197        let flags: u8 = if self.connect_udp_bind { 0x01 } else { 0x00 };
198        buf.put_u8(flags);
199
200        // Host length and host
201        let host_bytes = self.target_host.as_bytes();
202        if let Ok(len) = VarInt::from_u64(host_bytes.len() as u64) {
203            len.encode(&mut buf);
204        }
205        buf.put_slice(host_bytes);
206
207        // Port (network byte order)
208        buf.put_u16(self.target_port);
209
210        buf.freeze()
211    }
212
213    /// Decode a request from wire format
214    pub fn decode<B: Buf>(buf: &mut B) -> Result<Self, ConnectError> {
215        if buf.remaining() < 1 {
216            return Err(ConnectError::InvalidRequest("buffer too short".into()));
217        }
218
219        let flags = buf.get_u8();
220        let connect_udp_bind = (flags & 0x01) != 0;
221
222        let host_len = VarInt::decode(buf)
223            .map_err(|_| ConnectError::InvalidRequest("invalid host length".into()))?;
224        let host_len = host_len.into_inner() as usize;
225
226        if buf.remaining() < host_len + 2 {
227            return Err(ConnectError::InvalidRequest(
228                "buffer too short for host".into(),
229            ));
230        }
231
232        let mut host_bytes = vec![0u8; host_len];
233        buf.copy_to_slice(&mut host_bytes);
234        let target_host = String::from_utf8(host_bytes)
235            .map_err(|_| ConnectError::InvalidRequest("invalid UTF-8 in host".into()))?;
236
237        let target_port = buf.get_u16();
238
239        Ok(Self {
240            target_host,
241            target_port,
242            connect_udp_bind,
243        })
244    }
245}
246
247impl fmt::Display for ConnectUdpRequest {
248    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
249        if self.is_bind_request() {
250            write!(
251                f,
252                "CONNECT-UDP-BIND {}:{}",
253                self.target_host, self.target_port
254            )
255        } else {
256            write!(f, "CONNECT-UDP {}:{}", self.target_host, self.target_port)
257        }
258    }
259}
260
261/// HTTP CONNECT-UDP Response
262///
263/// Represents the relay's response to a CONNECT-UDP request.
264/// Includes the allocated public address for bind requests.
265#[derive(Debug, Clone, PartialEq, Eq)]
266pub struct ConnectUdpResponse {
267    /// HTTP status code (200 = success, 4xx/5xx = error)
268    pub status: u16,
269    /// Public address allocated by relay (for bind requests)
270    pub proxy_public_address: Option<SocketAddr>,
271    /// Human-readable reason phrase
272    pub reason: Option<String>,
273}
274
275impl ConnectUdpResponse {
276    /// HTTP status code for success
277    pub const STATUS_OK: u16 = 200;
278    /// HTTP status code for bad request
279    pub const STATUS_BAD_REQUEST: u16 = 400;
280    /// HTTP status code for forbidden
281    pub const STATUS_FORBIDDEN: u16 = 403;
282    /// HTTP status code for not found
283    pub const STATUS_NOT_FOUND: u16 = 404;
284    /// HTTP status code for service unavailable
285    pub const STATUS_UNAVAILABLE: u16 = 503;
286
287    /// Create a successful response with an allocated public address
288    pub fn success(public_addr: Option<SocketAddr>) -> Self {
289        Self {
290            status: Self::STATUS_OK,
291            proxy_public_address: public_addr,
292            reason: None,
293        }
294    }
295
296    /// Create an error response
297    pub fn error(status: u16, reason: impl Into<String>) -> Self {
298        Self {
299            status,
300            proxy_public_address: None,
301            reason: Some(reason.into()),
302        }
303    }
304
305    /// Create a bad request response
306    pub fn bad_request(reason: impl Into<String>) -> Self {
307        Self::error(Self::STATUS_BAD_REQUEST, reason)
308    }
309
310    /// Create a forbidden response
311    pub fn forbidden(reason: impl Into<String>) -> Self {
312        Self::error(Self::STATUS_FORBIDDEN, reason)
313    }
314
315    /// Create a service unavailable response
316    pub fn unavailable(reason: impl Into<String>) -> Self {
317        Self::error(Self::STATUS_UNAVAILABLE, reason)
318    }
319
320    /// Check if this is a successful response
321    pub fn is_success(&self) -> bool {
322        self.status >= 200 && self.status < 300
323    }
324
325    /// Check if this is an error response
326    pub fn is_error(&self) -> bool {
327        self.status >= 400
328    }
329
330    /// Convert to a Result, extracting the public address on success
331    pub fn into_result(self) -> Result<Option<SocketAddr>, ConnectError> {
332        if self.is_success() {
333            Ok(self.proxy_public_address)
334        } else {
335            Err(ConnectError::Rejected {
336                status: self.status,
337                reason: self.reason.unwrap_or_else(|| "unknown".into()),
338            })
339        }
340    }
341
342    /// Encode the response as wire format
343    ///
344    /// Format: [status (2)] [flags (1)] [addr if present]
345    pub fn encode(&self) -> Bytes {
346        let mut buf = BytesMut::new();
347
348        // Status code
349        buf.put_u16(self.status);
350
351        // Flags: bit 0 = has address, bit 1 = has reason
352        let mut flags: u8 = 0;
353        if self.proxy_public_address.is_some() {
354            flags |= 0x01;
355        }
356        if self.reason.is_some() {
357            flags |= 0x02;
358        }
359        buf.put_u8(flags);
360
361        // Public address if present
362        if let Some(addr) = &self.proxy_public_address {
363            match addr.ip() {
364                IpAddr::V4(v4) => {
365                    buf.put_u8(4);
366                    buf.put_slice(&v4.octets());
367                }
368                IpAddr::V6(v6) => {
369                    buf.put_u8(6);
370                    buf.put_slice(&v6.octets());
371                }
372            }
373            buf.put_u16(addr.port());
374        }
375
376        // Reason if present
377        if let Some(reason) = &self.reason {
378            let reason_bytes = reason.as_bytes();
379            if let Ok(len) = VarInt::from_u64(reason_bytes.len() as u64) {
380                len.encode(&mut buf);
381            }
382            buf.put_slice(reason_bytes);
383        }
384
385        buf.freeze()
386    }
387
388    /// Decode a response from wire format
389    pub fn decode<B: Buf>(buf: &mut B) -> Result<Self, ConnectError> {
390        if buf.remaining() < 3 {
391            return Err(ConnectError::InvalidResponse("buffer too short".into()));
392        }
393
394        let status = buf.get_u16();
395        let flags = buf.get_u8();
396        let has_addr = (flags & 0x01) != 0;
397        let has_reason = (flags & 0x02) != 0;
398
399        let proxy_public_address = if has_addr {
400            if buf.remaining() < 1 {
401                return Err(ConnectError::InvalidResponse("missing IP version".into()));
402            }
403            let ip_version = buf.get_u8();
404            let ip = match ip_version {
405                4 => {
406                    if buf.remaining() < 6 {
407                        return Err(ConnectError::InvalidResponse("missing IPv4 address".into()));
408                    }
409                    let mut octets = [0u8; 4];
410                    buf.copy_to_slice(&mut octets);
411                    IpAddr::V4(Ipv4Addr::from(octets))
412                }
413                6 => {
414                    if buf.remaining() < 18 {
415                        return Err(ConnectError::InvalidResponse("missing IPv6 address".into()));
416                    }
417                    let mut octets = [0u8; 16];
418                    buf.copy_to_slice(&mut octets);
419                    IpAddr::V6(Ipv6Addr::from(octets))
420                }
421                _ => return Err(ConnectError::InvalidResponse("invalid IP version".into())),
422            };
423            let port = buf.get_u16();
424            Some(SocketAddr::new(ip, port))
425        } else {
426            None
427        };
428
429        let reason = if has_reason {
430            let reason_len = VarInt::decode(buf)
431                .map_err(|_| ConnectError::InvalidResponse("invalid reason length".into()))?;
432            let reason_len = reason_len.into_inner() as usize;
433
434            if buf.remaining() < reason_len {
435                return Err(ConnectError::InvalidResponse("missing reason text".into()));
436            }
437
438            let mut reason_bytes = vec![0u8; reason_len];
439            buf.copy_to_slice(&mut reason_bytes);
440            Some(
441                String::from_utf8(reason_bytes)
442                    .map_err(|_| ConnectError::InvalidResponse("invalid UTF-8 in reason".into()))?,
443            )
444        } else {
445            None
446        };
447
448        Ok(Self {
449            status,
450            proxy_public_address,
451            reason,
452        })
453    }
454}
455
456impl fmt::Display for ConnectUdpResponse {
457    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
458        write!(f, "{}", self.status)?;
459        if let Some(addr) = &self.proxy_public_address {
460            write!(f, " (public: {})", addr)?;
461        }
462        if let Some(reason) = &self.reason {
463            write!(f, " - {}", reason)?;
464        }
465        Ok(())
466    }
467}
468
469#[cfg(test)]
470mod tests {
471    use super::*;
472
473    #[test]
474    fn test_bind_any_request() {
475        let request = ConnectUdpRequest::bind_any();
476        assert!(request.is_bind_request());
477        assert!(request.is_bind_any());
478        assert_eq!(request.target_host, "::");
479        assert_eq!(request.target_port, 0);
480        assert!(request.target_addr().is_none());
481        assert_eq!(request.protocol(), CONNECT_UDP_BIND_PROTOCOL);
482    }
483
484    #[test]
485    fn test_bind_port_request() {
486        let request = ConnectUdpRequest::bind_port(9000);
487        assert!(request.is_bind_request());
488        assert!(!request.is_bind_any()); // Has specific port
489        assert_eq!(request.target_port, 9000);
490    }
491
492    #[test]
493    fn test_target_request() {
494        let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(192, 168, 1, 100)), 8080);
495        let request = ConnectUdpRequest::target(addr);
496        assert!(!request.is_bind_request());
497        assert!(!request.is_bind_any());
498        assert_eq!(request.target_addr(), Some(addr));
499        assert_eq!(request.protocol(), CONNECT_UDP_PROTOCOL);
500    }
501
502    #[test]
503    fn test_request_roundtrip() {
504        let original = ConnectUdpRequest::bind_any();
505        let encoded = original.encode();
506        let decoded = ConnectUdpRequest::decode(&mut encoded.clone()).unwrap();
507        assert_eq!(original, decoded);
508
509        let original =
510            ConnectUdpRequest::target(SocketAddr::new(IpAddr::V4(Ipv4Addr::new(10, 0, 0, 1)), 443));
511        let encoded = original.encode();
512        let decoded = ConnectUdpRequest::decode(&mut encoded.clone()).unwrap();
513        assert_eq!(original, decoded);
514    }
515
516    #[test]
517    fn test_request_display() {
518        let bind = ConnectUdpRequest::bind_any();
519        assert!(bind.to_string().contains("CONNECT-UDP-BIND"));
520
521        let target = ConnectUdpRequest::target(SocketAddr::new(
522            IpAddr::V4(Ipv4Addr::new(192, 168, 1, 1)),
523            80,
524        ));
525        assert!(target.to_string().contains("CONNECT-UDP"));
526        assert!(target.to_string().contains("192.168.1.1:80"));
527    }
528
529    #[test]
530    fn test_success_response() {
531        let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(203, 0, 113, 50)), 9000);
532        let response = ConnectUdpResponse::success(Some(addr));
533        assert!(response.is_success());
534        assert!(!response.is_error());
535        assert_eq!(response.proxy_public_address, Some(addr));
536        assert!(response.reason.is_none());
537    }
538
539    #[test]
540    fn test_error_response() {
541        let response = ConnectUdpResponse::bad_request("invalid target");
542        assert!(!response.is_success());
543        assert!(response.is_error());
544        assert_eq!(response.status, 400);
545        assert_eq!(response.reason, Some("invalid target".to_string()));
546    }
547
548    #[test]
549    fn test_response_roundtrip_success() {
550        let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(203, 0, 113, 50)), 9000);
551        let original = ConnectUdpResponse::success(Some(addr));
552        let encoded = original.encode();
553        let decoded = ConnectUdpResponse::decode(&mut encoded.clone()).unwrap();
554        assert_eq!(original, decoded);
555    }
556
557    #[test]
558    fn test_response_roundtrip_success_no_addr() {
559        let original = ConnectUdpResponse::success(None);
560        let encoded = original.encode();
561        let decoded = ConnectUdpResponse::decode(&mut encoded.clone()).unwrap();
562        assert_eq!(original, decoded);
563    }
564
565    #[test]
566    fn test_response_roundtrip_error() {
567        let original = ConnectUdpResponse::forbidden("rate limited");
568        let encoded = original.encode();
569        let decoded = ConnectUdpResponse::decode(&mut encoded.clone()).unwrap();
570        assert_eq!(original, decoded);
571    }
572
573    #[test]
574    fn test_response_roundtrip_ipv6() {
575        let addr = SocketAddr::new(
576            IpAddr::V6(Ipv6Addr::new(0x2001, 0xdb8, 0, 0, 0, 0, 0, 1)),
577            8443,
578        );
579        let original = ConnectUdpResponse::success(Some(addr));
580        let encoded = original.encode();
581        let decoded = ConnectUdpResponse::decode(&mut encoded.clone()).unwrap();
582        assert_eq!(original, decoded);
583    }
584
585    #[test]
586    fn test_into_result_success() {
587        let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(1, 2, 3, 4)), 1234);
588        let response = ConnectUdpResponse::success(Some(addr));
589        let result = response.into_result();
590        assert!(result.is_ok());
591        assert_eq!(result.unwrap(), Some(addr));
592    }
593
594    #[test]
595    fn test_into_result_error() {
596        let response = ConnectUdpResponse::unavailable("no capacity");
597        let result = response.into_result();
598        assert!(result.is_err());
599        match result.unwrap_err() {
600            ConnectError::Rejected { status, reason } => {
601                assert_eq!(status, 503);
602                assert_eq!(reason, "no capacity");
603            }
604            _ => panic!("Expected Rejected error"),
605        }
606    }
607
608    #[test]
609    fn test_response_display() {
610        let success = ConnectUdpResponse::success(Some(SocketAddr::new(
611            IpAddr::V4(Ipv4Addr::new(1, 2, 3, 4)),
612            5678,
613        )));
614        let display = success.to_string();
615        assert!(display.contains("200"));
616        assert!(display.contains("1.2.3.4:5678"));
617
618        let error = ConnectUdpResponse::forbidden("rate limit exceeded");
619        let display = error.to_string();
620        assert!(display.contains("403"));
621        assert!(display.contains("rate limit exceeded"));
622    }
623}