1use bytes::{Buf, BufMut, Bytes, BytesMut};
59use std::fmt;
60use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr};
61use thiserror::Error;
62
63use crate::VarInt;
64use crate::coding::{self, Codec};
65
/// Protocol identifier reported by [`ConnectUdpRequest::protocol`] for a
/// request that targets one fixed address (i.e. not a bind request).
pub const CONNECT_UDP_PROTOCOL: &str = "connect-udp";

/// Protocol identifier reported by [`ConnectUdpRequest::protocol`] for a
/// bind request.
pub const CONNECT_UDP_BIND_PROTOCOL: &str = "connect-udp-bind";

/// Wildcard host used by bind requests: `::`, the IPv6 unspecified address.
pub const BIND_ANY_HOST: &str = "::";

/// Wildcard port (`0`) used by [`ConnectUdpRequest::bind_any`].
pub const BIND_ANY_PORT: u16 = 0;
78#[derive(Debug, Error)]
80pub enum ConnectError {
81 #[error("invalid request: {0}")]
83 InvalidRequest(String),
84
85 #[error("invalid response: {0}")]
87 InvalidResponse(String),
88
89 #[error("rejected: status {status}, reason: {reason}")]
91 Rejected {
92 status: u16,
94 reason: String,
96 },
97
98 #[error("codec error")]
100 Codec,
101
102 #[error("connection failed: {0}")]
104 ConnectionFailed(String),
105}
106
/// A CONNECT-UDP request: either a bind request or a request aimed at a
/// single target host and port.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct ConnectUdpRequest {
    /// Target host as text. [`ConnectUdpRequest::target`] stores an IP
    /// literal here; the bind constructors store [`BIND_ANY_HOST`]. A decoded
    /// request may hold any UTF-8 string.
    pub target_host: String,
    /// Target port (`0` for [`ConnectUdpRequest::bind_any`]).
    pub target_port: u16,
    /// `true` for a bind request; encoded as bit 0 of the flags byte.
    pub connect_udp_bind: bool,
}
121
122impl ConnectUdpRequest {
123 pub fn bind_any() -> Self {
128 Self {
129 target_host: BIND_ANY_HOST.to_string(),
130 target_port: BIND_ANY_PORT,
131 connect_udp_bind: true,
132 }
133 }
134
135 pub fn bind_port(port: u16) -> Self {
140 Self {
141 target_host: BIND_ANY_HOST.to_string(),
142 target_port: port,
143 connect_udp_bind: true,
144 }
145 }
146
147 pub fn target(addr: SocketAddr) -> Self {
152 Self {
153 target_host: addr.ip().to_string(),
154 target_port: addr.port(),
155 connect_udp_bind: false,
156 }
157 }
158
159 pub fn is_bind_request(&self) -> bool {
161 self.connect_udp_bind
162 }
163
164 pub fn is_bind_any(&self) -> bool {
166 self.connect_udp_bind
167 && (self.target_host == BIND_ANY_HOST || self.target_host == "0.0.0.0")
168 && self.target_port == BIND_ANY_PORT
169 }
170
171 pub fn target_addr(&self) -> Option<SocketAddr> {
173 if self.is_bind_request() {
174 return None;
175 }
176
177 let ip: IpAddr = self.target_host.parse().ok()?;
178 Some(SocketAddr::new(ip, self.target_port))
179 }
180
181 pub fn protocol(&self) -> &'static str {
183 if self.connect_udp_bind {
184 CONNECT_UDP_BIND_PROTOCOL
185 } else {
186 CONNECT_UDP_PROTOCOL
187 }
188 }
189
190 pub fn encode(&self) -> Bytes {
194 let mut buf = BytesMut::new();
195
196 let flags: u8 = if self.connect_udp_bind { 0x01 } else { 0x00 };
198 buf.put_u8(flags);
199
200 let host_bytes = self.target_host.as_bytes();
202 if let Ok(len) = VarInt::from_u64(host_bytes.len() as u64) {
203 len.encode(&mut buf);
204 }
205 buf.put_slice(host_bytes);
206
207 buf.put_u16(self.target_port);
209
210 buf.freeze()
211 }
212
213 pub fn decode<B: Buf>(buf: &mut B) -> Result<Self, ConnectError> {
215 if buf.remaining() < 1 {
216 return Err(ConnectError::InvalidRequest("buffer too short".into()));
217 }
218
219 let flags = buf.get_u8();
220 let connect_udp_bind = (flags & 0x01) != 0;
221
222 let host_len = VarInt::decode(buf)
223 .map_err(|_| ConnectError::InvalidRequest("invalid host length".into()))?;
224 let host_len = host_len.into_inner() as usize;
225
226 if buf.remaining() < host_len + 2 {
227 return Err(ConnectError::InvalidRequest(
228 "buffer too short for host".into(),
229 ));
230 }
231
232 let mut host_bytes = vec![0u8; host_len];
233 buf.copy_to_slice(&mut host_bytes);
234 let target_host = String::from_utf8(host_bytes)
235 .map_err(|_| ConnectError::InvalidRequest("invalid UTF-8 in host".into()))?;
236
237 let target_port = buf.get_u16();
238
239 Ok(Self {
240 target_host,
241 target_port,
242 connect_udp_bind,
243 })
244 }
245}
246
247impl fmt::Display for ConnectUdpRequest {
248 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
249 if self.is_bind_request() {
250 write!(
251 f,
252 "CONNECT-UDP-BIND {}:{}",
253 self.target_host, self.target_port
254 )
255 } else {
256 write!(f, "CONNECT-UDP {}:{}", self.target_host, self.target_port)
257 }
258 }
259}
260
/// Reply to a [`ConnectUdpRequest`]: a status code, plus an optional public
/// address and an optional human-readable reason.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct ConnectUdpResponse {
    /// Status code; see the `STATUS_*` associated constants.
    pub status: u16,
    /// Public address reported in the response (set by
    /// [`ConnectUdpResponse::success`]); `None` when absent.
    pub proxy_public_address: Option<SocketAddr>,
    /// Reason text (set by the error constructors); `None` when absent.
    pub reason: Option<String>,
}
274
275impl ConnectUdpResponse {
276 pub const STATUS_OK: u16 = 200;
278 pub const STATUS_BAD_REQUEST: u16 = 400;
280 pub const STATUS_FORBIDDEN: u16 = 403;
282 pub const STATUS_NOT_FOUND: u16 = 404;
284 pub const STATUS_UNAVAILABLE: u16 = 503;
286
287 pub fn success(public_addr: Option<SocketAddr>) -> Self {
289 Self {
290 status: Self::STATUS_OK,
291 proxy_public_address: public_addr,
292 reason: None,
293 }
294 }
295
296 pub fn error(status: u16, reason: impl Into<String>) -> Self {
298 Self {
299 status,
300 proxy_public_address: None,
301 reason: Some(reason.into()),
302 }
303 }
304
305 pub fn bad_request(reason: impl Into<String>) -> Self {
307 Self::error(Self::STATUS_BAD_REQUEST, reason)
308 }
309
310 pub fn forbidden(reason: impl Into<String>) -> Self {
312 Self::error(Self::STATUS_FORBIDDEN, reason)
313 }
314
315 pub fn unavailable(reason: impl Into<String>) -> Self {
317 Self::error(Self::STATUS_UNAVAILABLE, reason)
318 }
319
320 pub fn is_success(&self) -> bool {
322 self.status >= 200 && self.status < 300
323 }
324
325 pub fn is_error(&self) -> bool {
327 self.status >= 400
328 }
329
330 pub fn into_result(self) -> Result<Option<SocketAddr>, ConnectError> {
332 if self.is_success() {
333 Ok(self.proxy_public_address)
334 } else {
335 Err(ConnectError::Rejected {
336 status: self.status,
337 reason: self.reason.unwrap_or_else(|| "unknown".into()),
338 })
339 }
340 }
341
342 pub fn encode(&self) -> Bytes {
346 let mut buf = BytesMut::new();
347
348 buf.put_u16(self.status);
350
351 let mut flags: u8 = 0;
353 if self.proxy_public_address.is_some() {
354 flags |= 0x01;
355 }
356 if self.reason.is_some() {
357 flags |= 0x02;
358 }
359 buf.put_u8(flags);
360
361 if let Some(addr) = &self.proxy_public_address {
363 match addr.ip() {
364 IpAddr::V4(v4) => {
365 buf.put_u8(4);
366 buf.put_slice(&v4.octets());
367 }
368 IpAddr::V6(v6) => {
369 buf.put_u8(6);
370 buf.put_slice(&v6.octets());
371 }
372 }
373 buf.put_u16(addr.port());
374 }
375
376 if let Some(reason) = &self.reason {
378 let reason_bytes = reason.as_bytes();
379 if let Ok(len) = VarInt::from_u64(reason_bytes.len() as u64) {
380 len.encode(&mut buf);
381 }
382 buf.put_slice(reason_bytes);
383 }
384
385 buf.freeze()
386 }
387
388 pub fn decode<B: Buf>(buf: &mut B) -> Result<Self, ConnectError> {
390 if buf.remaining() < 3 {
391 return Err(ConnectError::InvalidResponse("buffer too short".into()));
392 }
393
394 let status = buf.get_u16();
395 let flags = buf.get_u8();
396 let has_addr = (flags & 0x01) != 0;
397 let has_reason = (flags & 0x02) != 0;
398
399 let proxy_public_address = if has_addr {
400 if buf.remaining() < 1 {
401 return Err(ConnectError::InvalidResponse("missing IP version".into()));
402 }
403 let ip_version = buf.get_u8();
404 let ip = match ip_version {
405 4 => {
406 if buf.remaining() < 6 {
407 return Err(ConnectError::InvalidResponse("missing IPv4 address".into()));
408 }
409 let mut octets = [0u8; 4];
410 buf.copy_to_slice(&mut octets);
411 IpAddr::V4(Ipv4Addr::from(octets))
412 }
413 6 => {
414 if buf.remaining() < 18 {
415 return Err(ConnectError::InvalidResponse("missing IPv6 address".into()));
416 }
417 let mut octets = [0u8; 16];
418 buf.copy_to_slice(&mut octets);
419 IpAddr::V6(Ipv6Addr::from(octets))
420 }
421 _ => return Err(ConnectError::InvalidResponse("invalid IP version".into())),
422 };
423 let port = buf.get_u16();
424 Some(SocketAddr::new(ip, port))
425 } else {
426 None
427 };
428
429 let reason = if has_reason {
430 let reason_len = VarInt::decode(buf)
431 .map_err(|_| ConnectError::InvalidResponse("invalid reason length".into()))?;
432 let reason_len = reason_len.into_inner() as usize;
433
434 if buf.remaining() < reason_len {
435 return Err(ConnectError::InvalidResponse("missing reason text".into()));
436 }
437
438 let mut reason_bytes = vec![0u8; reason_len];
439 buf.copy_to_slice(&mut reason_bytes);
440 Some(
441 String::from_utf8(reason_bytes)
442 .map_err(|_| ConnectError::InvalidResponse("invalid UTF-8 in reason".into()))?,
443 )
444 } else {
445 None
446 };
447
448 Ok(Self {
449 status,
450 proxy_public_address,
451 reason,
452 })
453 }
454}
455
456impl fmt::Display for ConnectUdpResponse {
457 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
458 write!(f, "{}", self.status)?;
459 if let Some(addr) = &self.proxy_public_address {
460 write!(f, " (public: {})", addr)?;
461 }
462 if let Some(reason) = &self.reason {
463 write!(f, " - {}", reason)?;
464 }
465 Ok(())
466 }
467}
468
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_bind_any_request() {
        let request = ConnectUdpRequest::bind_any();
        assert!(request.is_bind_request());
        assert!(request.is_bind_any());
        assert_eq!(request.target_host, "::");
        assert_eq!(request.target_port, 0);
        assert!(request.target_addr().is_none());
        assert_eq!(request.protocol(), CONNECT_UDP_BIND_PROTOCOL);
    }

    #[test]
    fn test_bind_port_request() {
        let request = ConnectUdpRequest::bind_port(9000);
        assert!(request.is_bind_request());
        assert!(!request.is_bind_any());
        assert_eq!(request.target_port, 9000);
    }

    #[test]
    fn test_bind_any_accepts_ipv4_wildcard() {
        let request = ConnectUdpRequest {
            target_host: "0.0.0.0".to_string(),
            target_port: 0,
            connect_udp_bind: true,
        };
        assert!(request.is_bind_any());
    }

    #[test]
    fn test_target_request() {
        let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(192, 168, 1, 100)), 8080);
        let request = ConnectUdpRequest::target(addr);
        assert!(!request.is_bind_request());
        assert!(!request.is_bind_any());
        assert_eq!(request.target_addr(), Some(addr));
        assert_eq!(request.protocol(), CONNECT_UDP_PROTOCOL);
    }

    #[test]
    fn test_target_request_ipv6() {
        let addr = SocketAddr::new(IpAddr::V6(Ipv6Addr::LOCALHOST), 53);
        let request = ConnectUdpRequest::target(addr);
        assert_eq!(request.target_host, "::1");
        assert_eq!(request.target_addr(), Some(addr));

        let decoded = ConnectUdpRequest::decode(&mut request.encode()).unwrap();
        assert_eq!(request, decoded);
    }

    #[test]
    fn test_target_addr_hostname_is_none() {
        let request = ConnectUdpRequest {
            target_host: "example.com".to_string(),
            target_port: 443,
            connect_udp_bind: false,
        };
        assert!(request.target_addr().is_none());
    }

    #[test]
    fn test_request_roundtrip() {
        let original = ConnectUdpRequest::bind_any();
        let encoded = original.encode();
        let decoded = ConnectUdpRequest::decode(&mut encoded.clone()).unwrap();
        assert_eq!(original, decoded);

        let original =
            ConnectUdpRequest::target(SocketAddr::new(IpAddr::V4(Ipv4Addr::new(10, 0, 0, 1)), 443));
        let encoded = original.encode();
        let decoded = ConnectUdpRequest::decode(&mut encoded.clone()).unwrap();
        assert_eq!(original, decoded);
    }

    #[test]
    fn test_request_decode_truncated() {
        let encoded =
            ConnectUdpRequest::target(SocketAddr::new(IpAddr::V4(Ipv4Addr::new(10, 0, 0, 1)), 443))
                .encode();

        let mut empty: &[u8] = &[];
        assert!(matches!(
            ConnectUdpRequest::decode(&mut empty),
            Err(ConnectError::InvalidRequest(_))
        ));

        // Flags byte only: the host length is missing.
        let mut flags_only = &encoded[..1];
        assert!(matches!(
            ConnectUdpRequest::decode(&mut flags_only),
            Err(ConnectError::InvalidRequest(_))
        ));

        // Last port byte missing.
        let mut short = &encoded[..encoded.len() - 1];
        assert!(matches!(
            ConnectUdpRequest::decode(&mut short),
            Err(ConnectError::InvalidRequest(_))
        ));
    }

    #[test]
    fn test_request_display() {
        let bind = ConnectUdpRequest::bind_any();
        assert!(bind.to_string().contains("CONNECT-UDP-BIND"));

        let target = ConnectUdpRequest::target(SocketAddr::new(
            IpAddr::V4(Ipv4Addr::new(192, 168, 1, 1)),
            80,
        ));
        assert!(target.to_string().contains("CONNECT-UDP"));
        assert!(target.to_string().contains("192.168.1.1:80"));
    }

    #[test]
    fn test_success_response() {
        let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(203, 0, 113, 50)), 9000);
        let response = ConnectUdpResponse::success(Some(addr));
        assert!(response.is_success());
        assert!(!response.is_error());
        assert_eq!(response.proxy_public_address, Some(addr));
        assert!(response.reason.is_none());
    }

    #[test]
    fn test_error_response() {
        let response = ConnectUdpResponse::bad_request("invalid target");
        assert!(!response.is_success());
        assert!(response.is_error());
        assert_eq!(response.status, 400);
        assert_eq!(response.reason, Some("invalid target".to_string()));
    }

    #[test]
    fn test_status_classification() {
        assert!(ConnectUdpResponse::error(299, "edge").is_success());

        let redirect = ConnectUdpResponse::error(300, "redirect");
        assert!(!redirect.is_success());
        assert!(!redirect.is_error());

        assert!(ConnectUdpResponse::error(ConnectUdpResponse::STATUS_NOT_FOUND, "gone").is_error());
    }

    #[test]
    fn test_response_roundtrip_success() {
        let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(203, 0, 113, 50)), 9000);
        let original = ConnectUdpResponse::success(Some(addr));
        let encoded = original.encode();
        let decoded = ConnectUdpResponse::decode(&mut encoded.clone()).unwrap();
        assert_eq!(original, decoded);
    }

    #[test]
    fn test_response_roundtrip_success_no_addr() {
        let original = ConnectUdpResponse::success(None);
        let encoded = original.encode();
        let decoded = ConnectUdpResponse::decode(&mut encoded.clone()).unwrap();
        assert_eq!(original, decoded);
    }

    #[test]
    fn test_response_roundtrip_error() {
        let original = ConnectUdpResponse::forbidden("rate limited");
        let encoded = original.encode();
        let decoded = ConnectUdpResponse::decode(&mut encoded.clone()).unwrap();
        assert_eq!(original, decoded);
    }

    #[test]
    fn test_response_roundtrip_non_ascii_reason() {
        let original = ConnectUdpResponse::bad_request("héllo — ünïcode");
        let encoded = original.encode();
        let decoded = ConnectUdpResponse::decode(&mut encoded.clone()).unwrap();
        assert_eq!(original, decoded);
    }

    #[test]
    fn test_response_roundtrip_ipv6() {
        let addr = SocketAddr::new(
            IpAddr::V6(Ipv6Addr::new(0x2001, 0xdb8, 0, 0, 0, 0, 0, 1)),
            8443,
        );
        let original = ConnectUdpResponse::success(Some(addr));
        let encoded = original.encode();
        let decoded = ConnectUdpResponse::decode(&mut encoded.clone()).unwrap();
        assert_eq!(original, decoded);
    }

    #[test]
    fn test_response_decode_malformed() {
        // Fewer than the 3 header bytes.
        let mut too_short: &[u8] = &[0x00, 0xC8];
        assert!(matches!(
            ConnectUdpResponse::decode(&mut too_short),
            Err(ConnectError::InvalidResponse(_))
        ));

        // Status 200, address flag set, unknown IP version 5.
        let mut bad_version: &[u8] = &[0x00, 0xC8, 0x01, 0x05];
        assert!(matches!(
            ConnectUdpResponse::decode(&mut bad_version),
            Err(ConnectError::InvalidResponse(_))
        ));

        // Address flag set but the IP version byte is missing.
        let mut no_version: &[u8] = &[0x00, 0xC8, 0x01];
        assert!(matches!(
            ConnectUdpResponse::decode(&mut no_version),
            Err(ConnectError::InvalidResponse(_))
        ));

        // IPv4 address with a truncated port.
        let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(1, 2, 3, 4)), 1234);
        let with_addr = ConnectUdpResponse::success(Some(addr)).encode();
        let mut short_addr = &with_addr[..with_addr.len() - 1];
        assert!(matches!(
            ConnectUdpResponse::decode(&mut short_addr),
            Err(ConnectError::InvalidResponse(_))
        ));

        // Reason text shorter than its declared length.
        let with_reason = ConnectUdpResponse::forbidden("rate limited").encode();
        let mut short_reason = &with_reason[..with_reason.len() - 1];
        assert!(matches!(
            ConnectUdpResponse::decode(&mut short_reason),
            Err(ConnectError::InvalidResponse(_))
        ));
    }

    #[test]
    fn test_into_result_success() {
        let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(1, 2, 3, 4)), 1234);
        let response = ConnectUdpResponse::success(Some(addr));
        let result = response.into_result();
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), Some(addr));
    }

    #[test]
    fn test_into_result_error() {
        let response = ConnectUdpResponse::unavailable("no capacity");
        let result = response.into_result();
        assert!(result.is_err());
        match result.unwrap_err() {
            ConnectError::Rejected { status, reason } => {
                assert_eq!(status, 503);
                assert_eq!(reason, "no capacity");
            }
            _ => panic!("Expected Rejected error"),
        }
    }

    #[test]
    fn test_response_display() {
        let success = ConnectUdpResponse::success(Some(SocketAddr::new(
            IpAddr::V4(Ipv4Addr::new(1, 2, 3, 4)),
            5678,
        )));
        let display = success.to_string();
        assert!(display.contains("200"));
        assert!(display.contains("1.2.3.4:5678"));

        let error = ConnectUdpResponse::forbidden("rate limit exceeded");
        let display = error.to_string();
        assert!(display.contains("403"));
        assert!(display.contains("rate limit exceeded"));
    }
}