use bytes::{Buf, BufMut, Bytes, BytesMut};
use std::fmt;
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr};
use thiserror::Error;
use crate::VarInt;
use crate::coding::Codec;
/// Protocol token for a plain (non-bind) CONNECT-UDP request; returned by
/// `ConnectUdpRequest::protocol` when the bind flag is not set.
pub const CONNECT_UDP_PROTOCOL: &str = "connect-udp";
/// Protocol token for a CONNECT-UDP-BIND request; returned by
/// `ConnectUdpRequest::protocol` when the bind flag is set.
pub const CONNECT_UDP_BIND_PROTOCOL: &str = "connect-udp-bind";
/// Wildcard host used by bind requests (`::`, the IPv6 unspecified address).
/// `ConnectUdpRequest::is_bind_any` also accepts `0.0.0.0` as a wildcard.
pub const BIND_ANY_HOST: &str = "::";
/// Port used by `ConnectUdpRequest::bind_any`; `is_bind_any` requires this value.
/// NOTE(review): presumably means "let the proxy choose a port" — confirm with the proxy side.
pub const BIND_ANY_PORT: u16 = 0;
/// Errors produced by CONNECT-UDP request/response handling.
#[derive(Debug, Error)]
pub enum ConnectError {
/// A request could not be decoded (e.g. truncated buffer, bad host length, invalid UTF-8 host).
#[error("invalid request: {0}")]
InvalidRequest(String),
/// A response could not be decoded (e.g. truncated buffer, unknown IP version, invalid UTF-8 reason).
#[error("invalid response: {0}")]
InvalidResponse(String),
/// The proxy answered with a non-2xx status; produced by `ConnectUdpResponse::into_result`.
#[error("rejected: status {status}, reason: {reason}")]
Rejected {
/// Status code carried by the response.
status: u16,
/// Reason text from the response, or `"unknown"` when the response carried none.
reason: String,
},
/// Generic codec failure.
/// NOTE(review): not constructed anywhere in this file's visible code — confirm where it is used.
#[error("codec error")]
Codec,
/// Connection-level failure with a description.
/// NOTE(review): not constructed in this file's visible code — presumably used by the transport layer; confirm.
#[error("connection failed: {0}")]
ConnectionFailed(String),
}
/// A CONNECT-UDP or CONNECT-UDP-BIND request, as serialized by `encode` and parsed by `decode`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ConnectUdpRequest {
/// Target host: an IP literal when built via `target`, `BIND_ANY_HOST` for the bind
/// constructors, or any UTF-8 string received via `decode` (hostnames are not resolved here).
pub target_host: String,
/// Target port, or the requested bind port for bind requests.
pub target_port: u16,
/// `true` for CONNECT-UDP-BIND, `false` for plain CONNECT-UDP; encoded as flag bit `0x01`.
pub connect_udp_bind: bool,
}
impl ConnectUdpRequest {
    /// Bit in the leading flags byte that marks a CONNECT-UDP-BIND request.
    const FLAG_BIND: u8 = 0x01;

    /// Creates a bind request for the wildcard host ([`BIND_ANY_HOST`]) on [`BIND_ANY_PORT`].
    pub fn bind_any() -> Self {
        Self::bind_port(BIND_ANY_PORT)
    }

    /// Creates a bind request for the wildcard host ([`BIND_ANY_HOST`]) on `port`.
    pub fn bind_port(port: u16) -> Self {
        Self {
            target_host: BIND_ANY_HOST.to_string(),
            target_port: port,
            connect_udp_bind: true,
        }
    }

    /// Creates a plain (non-bind) request whose host is the textual form of `addr`'s IP.
    pub fn target(addr: SocketAddr) -> Self {
        Self {
            target_host: addr.ip().to_string(),
            target_port: addr.port(),
            connect_udp_bind: false,
        }
    }

    /// Returns `true` if this is a CONNECT-UDP-BIND request.
    pub fn is_bind_request(&self) -> bool {
        self.connect_udp_bind
    }

    /// Returns `true` for a bind request on a wildcard host (`::` or `0.0.0.0`) whose port is
    /// [`BIND_ANY_PORT`].
    pub fn is_bind_any(&self) -> bool {
        self.connect_udp_bind
            && (self.target_host == BIND_ANY_HOST || self.target_host == "0.0.0.0")
            && self.target_port == BIND_ANY_PORT
    }

    /// Returns the target as a socket address.
    ///
    /// Returns `None` for bind requests, and when `target_host` is not an IP literal
    /// (hostnames are not resolved).
    pub fn target_addr(&self) -> Option<SocketAddr> {
        if self.is_bind_request() {
            return None;
        }
        let ip: IpAddr = self.target_host.parse().ok()?;
        Some(SocketAddr::new(ip, self.target_port))
    }

    /// Alias for [`Self::target_addr`].
    pub fn target_address(&self) -> Option<SocketAddr> {
        self.target_addr()
    }

    /// Returns [`CONNECT_UDP_BIND_PROTOCOL`] for bind requests, otherwise
    /// [`CONNECT_UDP_PROTOCOL`].
    pub fn protocol(&self) -> &'static str {
        if self.connect_udp_bind {
            CONNECT_UDP_BIND_PROTOCOL
        } else {
            CONNECT_UDP_PROTOCOL
        }
    }

    /// Serializes the request as
    /// `flags (u8) | host length (VarInt) | host (UTF-8) | port (u16, big-endian)`.
    pub fn encode(&self) -> Bytes {
        let host_bytes = self.target_host.as_bytes();
        // flags + worst-case VarInt length prefix (8 bytes) + host + port.
        let mut buf = BytesMut::with_capacity(1 + 8 + host_bytes.len() + 2);
        let flags = if self.connect_udp_bind { Self::FLAG_BIND } else { 0x00 };
        buf.put_u8(flags);
        // A host too long for a VarInt cannot realistically exist in memory. If it ever did,
        // the length prefix would be omitted and the frame would not decode.
        // NOTE(review): surface this as an error if the signature is ever allowed to change.
        if let Ok(len) = VarInt::from_u64(host_bytes.len() as u64) {
            len.encode(&mut buf);
        }
        buf.put_slice(host_bytes);
        buf.put_u16(self.target_port);
        buf.freeze()
    }

    /// Parses a request in the format written by [`Self::encode`], advancing `buf` past it.
    ///
    /// Flag bits other than the bind bit are ignored.
    ///
    /// # Errors
    ///
    /// Returns [`ConnectError::InvalidRequest`] if the buffer is truncated, the host length is
    /// not a valid VarInt, or the host is not valid UTF-8.
    pub fn decode<B: Buf>(buf: &mut B) -> Result<Self, ConnectError> {
        if buf.remaining() < 1 {
            return Err(ConnectError::InvalidRequest("buffer too short".into()));
        }
        let flags = buf.get_u8();
        let connect_udp_bind = (flags & Self::FLAG_BIND) != 0;
        let host_len = VarInt::decode(buf)
            .map_err(|_| ConnectError::InvalidRequest("invalid host length".into()))?;
        // The VarInt range exceeds `usize` on 32-bit targets, and an `as` cast would silently
        // truncate and mis-frame the host. A length that does not fit in `usize` can never be
        // satisfied by the buffer, so saturate and let the bounds check below reject it.
        let host_len = usize::try_from(host_len.into_inner()).unwrap_or(usize::MAX);
        // The host bytes are followed by a 2-byte port. Compare via subtraction so that
        // `host_len + 2` cannot overflow.
        let remaining = buf.remaining();
        if remaining < 2 || remaining - 2 < host_len {
            return Err(ConnectError::InvalidRequest(
                "buffer too short for host".into(),
            ));
        }
        let mut host_bytes = vec![0u8; host_len];
        buf.copy_to_slice(&mut host_bytes);
        let target_host = String::from_utf8(host_bytes)
            .map_err(|_| ConnectError::InvalidRequest("invalid UTF-8 in host".into()))?;
        let target_port = buf.get_u16();
        Ok(Self {
            target_host,
            target_port,
            connect_udp_bind,
        })
    }
}
impl fmt::Display for ConnectUdpRequest {
    /// Formats as `CONNECT-UDP host:port` or `CONNECT-UDP-BIND host:port`.
    ///
    /// Hosts containing `:` (IPv6 literals such as `::`) are wrapped in brackets, following
    /// the `SocketAddr` convention. Without them the port separator would be ambiguous
    /// (e.g. `:::0`).
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let method = if self.is_bind_request() {
            "CONNECT-UDP-BIND"
        } else {
            "CONNECT-UDP"
        };
        let host = self.target_host.as_str();
        // Hostnames and IPv4 literals never contain ':', so this only catches IPv6 literals.
        // Hosts that are already bracketed are left alone.
        if host.contains(':') && !host.starts_with('[') {
            write!(f, "{} [{}]:{}", method, host, self.target_port)
        } else {
            write!(f, "{} {}:{}", method, host, self.target_port)
        }
    }
}
/// A proxy's reply to a `ConnectUdpRequest`, as serialized by `encode` and parsed by `decode`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ConnectUdpResponse {
/// HTTP-style status code; `is_success` treats 2xx as success and `is_error` treats >= 400 as error.
pub status: u16,
/// Public address reported by the proxy, if any; encoded only when present (flag bit `0x01`).
pub proxy_public_address: Option<SocketAddr>,
/// Optional human-readable reason (set by the error constructors); encoded only when present
/// (flag bit `0x02`).
pub reason: Option<String>,
}
impl ConnectUdpResponse {
    /// Success.
    pub const STATUS_OK: u16 = 200;
    /// Malformed or unacceptable request.
    pub const STATUS_BAD_REQUEST: u16 = 400;
    /// Request refused by policy.
    pub const STATUS_FORBIDDEN: u16 = 403;
    /// Requested resource not found.
    pub const STATUS_NOT_FOUND: u16 = 404;
    /// Proxy cannot serve the request right now.
    pub const STATUS_UNAVAILABLE: u16 = 503;

    /// Flag bit: a public address follows the flags byte.
    const FLAG_PUBLIC_ADDRESS: u8 = 0x01;
    /// Flag bit: a length-prefixed reason string follows (after the address, if any).
    const FLAG_REASON: u8 = 0x02;

    /// Builds a `200` response carrying the optional public address and no reason.
    pub fn success(public_addr: Option<SocketAddr>) -> Self {
        Self {
            status: Self::STATUS_OK,
            proxy_public_address: public_addr,
            reason: None,
        }
    }

    /// Builds a response with an arbitrary `status` and a `reason`, and no public address.
    pub fn error(status: u16, reason: impl Into<String>) -> Self {
        Self {
            status,
            proxy_public_address: None,
            reason: Some(reason.into()),
        }
    }

    /// Builds a `400 Bad Request` response.
    pub fn bad_request(reason: impl Into<String>) -> Self {
        Self::error(Self::STATUS_BAD_REQUEST, reason)
    }

    /// Builds a `403 Forbidden` response.
    pub fn forbidden(reason: impl Into<String>) -> Self {
        Self::error(Self::STATUS_FORBIDDEN, reason)
    }

    /// Builds a `503` (unavailable) response.
    pub fn unavailable(reason: impl Into<String>) -> Self {
        Self::error(Self::STATUS_UNAVAILABLE, reason)
    }

    /// Returns `true` for a 2xx status.
    pub fn is_success(&self) -> bool {
        (200..300).contains(&self.status)
    }

    /// Returns `true` for a status of 400 or above. 1xx/3xx are neither success nor error.
    pub fn is_error(&self) -> bool {
        self.status >= 400
    }

    /// Converts into the public address on success.
    ///
    /// # Errors
    ///
    /// Returns [`ConnectError::Rejected`] for any non-2xx status. The reason is `"unknown"`
    /// when the response carried none.
    pub fn into_result(self) -> Result<Option<SocketAddr>, ConnectError> {
        if self.is_success() {
            Ok(self.proxy_public_address)
        } else {
            Err(ConnectError::Rejected {
                status: self.status,
                reason: self.reason.unwrap_or_else(|| "unknown".into()),
            })
        }
    }

    /// Serializes the response as
    /// `status (u16) | flags (u8) | [ip version (4|6) | ip octets | port (u16)] |
    /// [reason length (VarInt) | reason (UTF-8)]`.
    ///
    /// The bracketed sections are present only when the matching flag bit is set.
    pub fn encode(&self) -> Bytes {
        let mut buf = BytesMut::new();
        buf.put_u16(self.status);
        let mut flags: u8 = 0;
        if self.proxy_public_address.is_some() {
            flags |= Self::FLAG_PUBLIC_ADDRESS;
        }
        if self.reason.is_some() {
            flags |= Self::FLAG_REASON;
        }
        buf.put_u8(flags);
        if let Some(addr) = &self.proxy_public_address {
            match addr.ip() {
                IpAddr::V4(v4) => {
                    buf.put_u8(4);
                    buf.put_slice(&v4.octets());
                }
                IpAddr::V6(v6) => {
                    buf.put_u8(6);
                    buf.put_slice(&v6.octets());
                }
            }
            buf.put_u16(addr.port());
        }
        if let Some(reason) = &self.reason {
            let reason_bytes = reason.as_bytes();
            // A reason too long for a VarInt cannot realistically exist in memory. If it ever
            // did, the length prefix would be omitted and the frame would not decode.
            // NOTE(review): surface this as an error if the signature is ever allowed to change.
            if let Ok(len) = VarInt::from_u64(reason_bytes.len() as u64) {
                len.encode(&mut buf);
            }
            buf.put_slice(reason_bytes);
        }
        buf.freeze()
    }

    /// Parses a response in the format written by [`Self::encode`], advancing `buf` past it.
    ///
    /// Flag bits other than `0x01` and `0x02` are ignored.
    ///
    /// # Errors
    ///
    /// Returns [`ConnectError::InvalidResponse`] if the buffer is truncated, the IP version tag
    /// is not 4 or 6, the reason length is not a valid VarInt, or the reason is not UTF-8.
    pub fn decode<B: Buf>(buf: &mut B) -> Result<Self, ConnectError> {
        if buf.remaining() < 3 {
            return Err(ConnectError::InvalidResponse("buffer too short".into()));
        }
        let status = buf.get_u16();
        let flags = buf.get_u8();
        let proxy_public_address = if (flags & Self::FLAG_PUBLIC_ADDRESS) != 0 {
            Some(Self::decode_public_address(buf)?)
        } else {
            None
        };
        let reason = if (flags & Self::FLAG_REASON) != 0 {
            Some(Self::decode_reason(buf)?)
        } else {
            None
        };
        Ok(Self {
            status,
            proxy_public_address,
            reason,
        })
    }

    /// Reads `ip version (4|6) | ip octets | port (u16)`.
    fn decode_public_address<B: Buf>(buf: &mut B) -> Result<SocketAddr, ConnectError> {
        if buf.remaining() < 1 {
            return Err(ConnectError::InvalidResponse("missing IP version".into()));
        }
        let ip = match buf.get_u8() {
            4 => {
                // 4 address octets + 2 port bytes.
                if buf.remaining() < 6 {
                    return Err(ConnectError::InvalidResponse("missing IPv4 address".into()));
                }
                let mut octets = [0u8; 4];
                buf.copy_to_slice(&mut octets);
                IpAddr::V4(Ipv4Addr::from(octets))
            }
            6 => {
                // 16 address octets + 2 port bytes.
                if buf.remaining() < 18 {
                    return Err(ConnectError::InvalidResponse("missing IPv6 address".into()));
                }
                let mut octets = [0u8; 16];
                buf.copy_to_slice(&mut octets);
                IpAddr::V6(Ipv6Addr::from(octets))
            }
            _ => return Err(ConnectError::InvalidResponse("invalid IP version".into())),
        };
        // The port bytes were covered by the length checks above.
        let port = buf.get_u16();
        Ok(SocketAddr::new(ip, port))
    }

    /// Reads `reason length (VarInt) | reason (UTF-8)`.
    fn decode_reason<B: Buf>(buf: &mut B) -> Result<String, ConnectError> {
        let reason_len = VarInt::decode(buf)
            .map_err(|_| ConnectError::InvalidResponse("invalid reason length".into()))?;
        // The VarInt range exceeds `usize` on 32-bit targets, and an `as` cast would silently
        // truncate. A length that does not fit in `usize` can never be satisfied by the buffer,
        // so saturate and let the bounds check reject it.
        let reason_len = usize::try_from(reason_len.into_inner()).unwrap_or(usize::MAX);
        if buf.remaining() < reason_len {
            return Err(ConnectError::InvalidResponse("missing reason text".into()));
        }
        let mut reason_bytes = vec![0u8; reason_len];
        buf.copy_to_slice(&mut reason_bytes);
        String::from_utf8(reason_bytes)
            .map_err(|_| ConnectError::InvalidResponse("invalid UTF-8 in reason".into()))
    }
}
impl fmt::Display for ConnectUdpResponse {
    /// Renders the status code, then ` (public: ADDR)` when a public address is present,
    /// then ` - REASON` when a reason is present.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self {
            status,
            proxy_public_address,
            reason,
        } = self;
        write!(f, "{}", status)?;
        if let Some(public) = proxy_public_address {
            write!(f, " (public: {})", public)?;
        }
        match reason {
            Some(why) => write!(f, " - {}", why),
            None => Ok(()),
        }
    }
}
#[cfg(test)]
mod tests {
    // Unit tests for the CONNECT-UDP request/response codec: constructors, predicates,
    // round-trips, Display output and decode error paths.
    use super::*;

    #[test]
    fn test_bind_any_request() {
        let request = ConnectUdpRequest::bind_any();
        assert!(request.is_bind_request());
        assert!(request.is_bind_any());
        assert_eq!(request.target_host, "::");
        assert_eq!(request.target_port, 0);
        assert!(request.target_addr().is_none());
        assert_eq!(request.protocol(), CONNECT_UDP_BIND_PROTOCOL);
    }

    #[test]
    fn test_bind_port_request() {
        let request = ConnectUdpRequest::bind_port(9000);
        assert!(request.is_bind_request());
        assert!(!request.is_bind_any());
        assert_eq!(request.target_port, 9000);
    }

    #[test]
    fn test_is_bind_any_accepts_ipv4_wildcard() {
        let request = ConnectUdpRequest {
            target_host: "0.0.0.0".to_string(),
            target_port: BIND_ANY_PORT,
            connect_udp_bind: true,
        };
        assert!(request.is_bind_any());
        // A wildcard host without the bind flag is not a bind-any request.
        let request = ConnectUdpRequest {
            connect_udp_bind: false,
            ..request
        };
        assert!(!request.is_bind_any());
    }

    #[test]
    fn test_target_request() {
        let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(192, 168, 1, 100)), 8080);
        let request = ConnectUdpRequest::target(addr);
        assert!(!request.is_bind_request());
        assert!(!request.is_bind_any());
        assert_eq!(request.target_addr(), Some(addr));
        assert_eq!(request.protocol(), CONNECT_UDP_PROTOCOL);
    }

    #[test]
    fn test_hostname_target_has_no_socket_addr() {
        let request = ConnectUdpRequest {
            target_host: "example.com".to_string(),
            target_port: 443,
            connect_udp_bind: false,
        };
        assert!(request.target_addr().is_none());
        assert_eq!(request.target_address(), request.target_addr());
        let decoded = ConnectUdpRequest::decode(&mut request.encode()).unwrap();
        assert_eq!(request, decoded);
    }

    #[test]
    fn test_request_roundtrip() {
        let original = ConnectUdpRequest::bind_any();
        let encoded = original.encode();
        let decoded = ConnectUdpRequest::decode(&mut encoded.clone()).unwrap();
        assert_eq!(original, decoded);
        let original =
            ConnectUdpRequest::target(SocketAddr::new(IpAddr::V4(Ipv4Addr::new(10, 0, 0, 1)), 443));
        let encoded = original.encode();
        let decoded = ConnectUdpRequest::decode(&mut encoded.clone()).unwrap();
        assert_eq!(original, decoded);
    }

    #[test]
    fn test_request_decode_rejects_truncated_input() {
        assert!(matches!(
            ConnectUdpRequest::decode(&mut Bytes::new()),
            Err(ConnectError::InvalidRequest(_))
        ));
        // Dropping the last port byte must be reported, not panic.
        let encoded = ConnectUdpRequest::bind_any().encode();
        let mut truncated = encoded.slice(..encoded.len() - 1);
        assert!(matches!(
            ConnectUdpRequest::decode(&mut truncated),
            Err(ConnectError::InvalidRequest(_))
        ));
    }

    #[test]
    fn test_request_display() {
        let bind = ConnectUdpRequest::bind_any();
        assert!(bind.to_string().contains("CONNECT-UDP-BIND"));
        let target = ConnectUdpRequest::target(SocketAddr::new(
            IpAddr::V4(Ipv4Addr::new(192, 168, 1, 1)),
            80,
        ));
        assert!(target.to_string().contains("CONNECT-UDP"));
        assert!(target.to_string().contains("192.168.1.1:80"));
    }

    #[test]
    fn test_success_response() {
        let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(203, 0, 113, 50)), 9000);
        let response = ConnectUdpResponse::success(Some(addr));
        assert!(response.is_success());
        assert!(!response.is_error());
        assert_eq!(response.proxy_public_address, Some(addr));
        assert!(response.reason.is_none());
    }

    #[test]
    fn test_error_response() {
        let response = ConnectUdpResponse::bad_request("invalid target");
        assert!(!response.is_success());
        assert!(response.is_error());
        assert_eq!(response.status, 400);
        assert_eq!(response.reason, Some("invalid target".to_string()));
    }

    #[test]
    fn test_redirect_status_is_neither_success_nor_error() {
        let response = ConnectUdpResponse {
            status: 302,
            proxy_public_address: None,
            reason: None,
        };
        assert!(!response.is_success());
        assert!(!response.is_error());
        match response.into_result() {
            Err(ConnectError::Rejected { status, reason }) => {
                assert_eq!(status, 302);
                assert_eq!(reason, "unknown");
            }
            other => panic!("Expected Rejected error, got {:?}", other),
        }
    }

    #[test]
    fn test_response_roundtrip_success() {
        let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(203, 0, 113, 50)), 9000);
        let original = ConnectUdpResponse::success(Some(addr));
        let encoded = original.encode();
        let decoded = ConnectUdpResponse::decode(&mut encoded.clone()).unwrap();
        assert_eq!(original, decoded);
    }

    #[test]
    fn test_response_roundtrip_success_no_addr() {
        let original = ConnectUdpResponse::success(None);
        let encoded = original.encode();
        let decoded = ConnectUdpResponse::decode(&mut encoded.clone()).unwrap();
        assert_eq!(original, decoded);
    }

    #[test]
    fn test_response_roundtrip_error() {
        let original = ConnectUdpResponse::forbidden("rate limited");
        let encoded = original.encode();
        let decoded = ConnectUdpResponse::decode(&mut encoded.clone()).unwrap();
        assert_eq!(original, decoded);
    }

    #[test]
    fn test_response_roundtrip_ipv6() {
        let addr = SocketAddr::new(
            IpAddr::V6(Ipv6Addr::new(0x2001, 0xdb8, 0, 0, 0, 0, 0, 1)),
            8443,
        );
        let original = ConnectUdpResponse::success(Some(addr));
        let encoded = original.encode();
        let decoded = ConnectUdpResponse::decode(&mut encoded.clone()).unwrap();
        assert_eq!(original, decoded);
    }

    #[test]
    fn test_response_roundtrip_addr_and_reason() {
        let original = ConnectUdpResponse {
            status: ConnectUdpResponse::STATUS_OK,
            proxy_public_address: Some(SocketAddr::new(
                IpAddr::V4(Ipv4Addr::new(198, 51, 100, 7)),
                4433,
            )),
            reason: Some("ok".to_string()),
        };
        let decoded = ConnectUdpResponse::decode(&mut original.encode()).unwrap();
        assert_eq!(original, decoded);
    }

    #[test]
    fn test_response_decode_rejects_malformed_input() {
        assert!(matches!(
            ConnectUdpResponse::decode(&mut Bytes::new()),
            Err(ConnectError::InvalidResponse(_))
        ));
        // Status 200, address flag set, unknown IP version tag 5.
        let mut bad_version = Bytes::from_static(&[0x00, 0xC8, 0x01, 0x05]);
        assert!(matches!(
            ConnectUdpResponse::decode(&mut bad_version),
            Err(ConnectError::InvalidResponse(_))
        ));
        // Status 200, address flag set, IPv4 tag but only three octets and no port.
        let mut short_v4 = Bytes::from_static(&[0x00, 0xC8, 0x01, 0x04, 1, 2, 3]);
        assert!(matches!(
            ConnectUdpResponse::decode(&mut short_v4),
            Err(ConnectError::InvalidResponse(_))
        ));
        // Status 403, reason flag set, length prefix 5 but only one byte of text.
        let mut short_reason = Bytes::from_static(&[0x01, 0x93, 0x02, 0x05, b'a']);
        assert!(matches!(
            ConnectUdpResponse::decode(&mut short_reason),
            Err(ConnectError::InvalidResponse(_))
        ));
    }

    #[test]
    fn test_into_result_success() {
        let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(1, 2, 3, 4)), 1234);
        let response = ConnectUdpResponse::success(Some(addr));
        let result = response.into_result();
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), Some(addr));
    }

    #[test]
    fn test_into_result_error() {
        let response = ConnectUdpResponse::unavailable("no capacity");
        let result = response.into_result();
        assert!(result.is_err());
        match result.unwrap_err() {
            ConnectError::Rejected { status, reason } => {
                assert_eq!(status, 503);
                assert_eq!(reason, "no capacity");
            }
            _ => panic!("Expected Rejected error"),
        }
    }

    #[test]
    fn test_response_display() {
        let success = ConnectUdpResponse::success(Some(SocketAddr::new(
            IpAddr::V4(Ipv4Addr::new(1, 2, 3, 4)),
            5678,
        )));
        let display = success.to_string();
        assert!(display.contains("200"));
        assert!(display.contains("1.2.3.4:5678"));
        let error = ConnectUdpResponse::forbidden("rate limit exceeded");
        let display = error.to_string();
        assert!(display.contains("403"));
        assert!(display.contains("rate limit exceeded"));
    }
}