1pub use std::net::{IpAddr, SocketAddr};
2use std::net::{Ipv4Addr, Ipv6Addr, SocketAddrV4, SocketAddrV6, ToSocketAddrs};
3use std::str::FromStr;
4
5use crate::codec::{RawDecode, RawEncode, RawEncodePurpose, RawFixedBytes};
6use crate::*;
7use std::cmp::Ordering;
8
/// Transport protocol carried by an [`Endpoint`].
///
/// The explicit discriminants fix the derived ordering as `Unk < Tcp < Udp`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord)]
pub enum Protocol {
    /// Protocol not known; rendered as `unk` by `Display`.
    Unk = 0,
    /// Rendered as `tcp` by `Display`.
    Tcp = 1,
    /// Rendered as `udp` by `Display`.
    Udp = 2,
}
15
/// Area tag of an [`Endpoint`]; shown as the first character of its string form.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum EndpointArea {
    /// `L`; the area assigned by `Default` and the `From` tuple conversions.
    Lan,
    /// `D`; reported by `Endpoint::is_sys_default`.
    Default,
    /// `W`; reported by `Endpoint::is_static_wan`.
    Wan,
    /// `M`; reported by both `Endpoint::is_static_wan` and `Endpoint::is_mapped_wan`.
    Mapped,
}
23
24#[derive(Copy, Clone, Eq)]
25pub struct Endpoint {
26 area: EndpointArea,
27 protocol: Protocol,
28 addr: SocketAddr,
29}
30
31impl Endpoint {
32 pub fn protocol(&self) -> Protocol {
33 self.protocol
34 }
35 pub fn set_protocol(&mut self, p: Protocol) {
36 self.protocol = p
37 }
38
39 pub fn addr(&self) -> &SocketAddr {
40 &self.addr
41 }
42
43 pub fn mut_addr(&mut self) -> &mut SocketAddr {
44 &mut self.addr
45 }
46
47 pub fn is_same_ip_version(&self, other: &Endpoint) -> bool {
48 self.addr.is_ipv4() == other.addr.is_ipv4()
49 }
50
51 pub fn is_same_ip_addr(&self, other: &Endpoint) -> bool {
52 let mut self_ip = self.addr;
53 self_ip.set_port(0);
54 let mut other_ip = other.addr;
55 other_ip.set_port(0);
56 self_ip == other_ip
57 }
58
59 pub fn default_of(ep: &Endpoint) -> Self {
60 match ep.protocol {
61 Protocol::Tcp => Self::default_tcp(ep),
62 Protocol::Udp => Self::default_udp(ep),
63 _ => Self {
64 area: EndpointArea::Lan,
65 protocol: Protocol::Unk,
66 addr: match ep.addr().is_ipv4() {
67 true => SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(0, 0, 0, 0), 0)),
68 false => SocketAddr::V6(SocketAddrV6::new(
69 Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 0),
70 0,
71 0,
72 0,
73 )),
74 },
75 },
76 }
77 }
78
79 pub fn default_tcp(ep: &Endpoint) -> Self {
80 Self {
81 area: EndpointArea::Lan,
82 protocol: Protocol::Tcp,
83 addr: match ep.addr().is_ipv4() {
84 true => SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(0, 0, 0, 0), 0)),
85 false => SocketAddr::V6(SocketAddrV6::new(
86 Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 0),
87 0,
88 0,
89 0,
90 )),
91 },
92 }
93 }
94
95 pub fn default_udp(ep: &Endpoint) -> Self {
96 Self {
97 area: EndpointArea::Lan,
98 protocol: Protocol::Udp,
99 addr: match ep.addr().is_ipv4() {
100 true => SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(0, 0, 0, 0), 0)),
101 false => SocketAddr::V6(SocketAddrV6::new(
102 Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 0),
103 0,
104 0,
105 0,
106 )),
107 },
108 }
109 }
110
111 pub fn is_udp(&self) -> bool {
112 self.protocol == Protocol::Udp
113 }
114 pub fn is_tcp(&self) -> bool {
115 self.protocol == Protocol::Tcp
116 }
117 pub fn is_sys_default(&self) -> bool {
118 self.area == EndpointArea::Default
119 }
120 pub fn is_static_wan(&self) -> bool {
121 self.area == EndpointArea::Wan
122 || self.area == EndpointArea::Mapped
123 }
124
125 pub fn is_mapped_wan(&self) -> bool {
126 self.area == EndpointArea::Mapped
127 }
128
129 pub fn set_area(&mut self, area: EndpointArea) {
130 self.area = area;
131 }
132}
133
134impl Default for Endpoint {
135 fn default() -> Self {
136 Self {
137 area: EndpointArea::Lan,
138 protocol: Protocol::Unk,
139 addr: SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(0, 0, 0, 0), 0)),
140 }
141 }
142}
143
144impl From<(Protocol, SocketAddr)> for Endpoint {
145 fn from(ps: (Protocol, SocketAddr)) -> Self {
146 Self {
147 area: EndpointArea::Lan,
148 protocol: ps.0,
149 addr: ps.1,
150 }
151 }
152}
153
154impl From<(Protocol, IpAddr, u16)> for Endpoint {
155 fn from(piu: (Protocol, IpAddr, u16)) -> Self {
156 Self {
157 area: EndpointArea::Lan,
158 protocol: piu.0,
159 addr: SocketAddr::new(piu.1, piu.2),
160 }
161 }
162}
163
164impl ToSocketAddrs for Endpoint {
165 type Iter = <SocketAddr as ToSocketAddrs>::Iter;
166 fn to_socket_addrs(&self) -> std::io::Result<Self::Iter> {
167 self.addr.to_socket_addrs()
168 }
169}
170
171impl PartialEq for Endpoint {
172 fn eq(&self, other: &Endpoint) -> bool {
173 self.protocol == other.protocol && self.addr == other.addr
174 }
175}
176
177impl PartialOrd for Endpoint {
178 fn partial_cmp(&self, other: &Endpoint) -> Option<std::cmp::Ordering> {
179 use std::cmp::Ordering::*;
180 match self.protocol.partial_cmp(&other.protocol).unwrap() {
181 Equal => match self.addr.ip().partial_cmp(&other.addr().ip()) {
182 None => self.addr.port().partial_cmp(&other.addr.port()),
183 Some(ord) => match ord {
184 Greater => Some(Greater),
185 Less => Some(Less),
186 Equal => self.addr.port().partial_cmp(&other.addr.port()),
187 },
188 },
189 Greater => Some(Greater),
190 Less => Some(Less),
191 }
192 }
193}
194
195impl Ord for Endpoint {
196 fn cmp(&self, other: &Self) -> Ordering {
197 self.partial_cmp(other).unwrap()
198 }
199}
200
201impl std::fmt::Debug for Endpoint {
202 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
203 write!(f, "{}", self)
204 }
205}
206
207
208impl std::fmt::Display for Endpoint {
209 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
210 let mut result = String::new();
211
212 result += match self.area {
213 EndpointArea::Lan => "L", EndpointArea::Default => "D", EndpointArea::Wan => "W", EndpointArea::Mapped => "M" };
218
219 result += match self.addr {
220 SocketAddr::V4(_) => "4",
221 SocketAddr::V6(_) => "6",
222 };
223
224 result += match self.protocol {
225 Protocol::Unk => "unk",
226 Protocol::Tcp => "tcp",
227 Protocol::Udp => "udp",
228 };
229
230 result += self.addr.to_string().as_str();
231
232 write!(f, "{}", &result)
233 }
234}
235
236impl FromStr for Endpoint {
237 type Err = BuckyError;
238 fn from_str(s: &str) -> Result<Self, Self::Err> {
239 let area = {
240 match &s[0..1] {
241 "W" => Ok(EndpointArea::Wan),
242 "M" => Ok(EndpointArea::Mapped),
243 "L" => Ok(EndpointArea::Lan),
244 "D" => Ok(EndpointArea::Default),
245 _ => Err(BuckyError::new(
246 BuckyErrorCode::InvalidInput,
247 "invalid endpoint string",
248 )),
249 }
250 }?;
251 let version_str = &s[1..2];
252
253 let protocol = {
254 match &s[2..5] {
255 "tcp" => Ok(Protocol::Tcp),
256 "udp" => Ok(Protocol::Udp),
257 _ => Err(BuckyError::new(
258 BuckyErrorCode::InvalidInput,
259 "invalid endpoint string",
260 )),
261 }
262 }?;
263
264 let addr = SocketAddr::from_str(&s[5..]).map_err(|_| {
265 BuckyError::new(BuckyErrorCode::InvalidInput, "invalid endpoint string")
266 })?;
267 if !(addr.is_ipv4() && version_str.eq("4") || addr.is_ipv6() && version_str.eq("6")) {
268 return Err(BuckyError::new(
269 BuckyErrorCode::InvalidInput,
270 "invalid endpoint string",
271 ));
272 }
273 Ok(Endpoint {
274 area,
275 protocol,
276 addr,
277 })
278 }
279}
280
281pub fn endpoints_to_string(eps: &[Endpoint]) -> String {
282 let mut s = "[".to_string();
283 if eps.len() > 0 {
284 s += eps[0].to_string().as_str();
285 }
286
287 if eps.len() > 1 {
288 for i in 1..eps.len() {
289 s += ",";
290 s += eps[i].to_string().as_str();
291 }
292 }
293 s += "]";
294 s
295}
296
// Bits of the flags byte that prefixes every binary-encoded endpoint.
// Bit 5 (0b0010_0000) is not assigned by this file.
const ENDPOINT_FLAG_DEFAULT: u8 = 0b0000_0001;

// Protocol bits; an unknown protocol sets neither.
const ENDPOINT_PROTOCOL_UNK: u8 = 0b0000_0000;
const ENDPOINT_PROTOCOL_TCP: u8 = 0b0000_0010;
const ENDPOINT_PROTOCOL_UDP: u8 = 0b0000_0100;

// IP family bits.
const ENDPOINT_IP_VERSION_4: u8 = 0b0000_1000;
const ENDPOINT_IP_VERSION_6: u8 = 0b0001_0000;
// Set for Wan and Mapped areas.
const ENDPOINT_FLAG_STATIC_WAN: u8 = 0b0100_0000;
// Set only by the SignedEndpoint encoder.
const ENDPOINT_FLAG_SIGNED: u8 = 0b1000_0000;
308
309#[derive(Clone)]
310pub struct SignedEndpoint(Endpoint);
311
312impl From<Endpoint> for SignedEndpoint {
313 fn from(ep: Endpoint) -> Self {
314 Self(ep)
315 }
316}
317
318impl Into<Endpoint> for SignedEndpoint {
319 fn into(self) -> Endpoint {
320 self.0
321 }
322}
323
324impl AsRef<Endpoint> for SignedEndpoint {
325 fn as_ref(&self) -> &Endpoint {
326 &self.0
327 }
328}
329
330impl RawFixedBytes for Endpoint {
331 fn raw_max_bytes() -> Option<usize> {
333 Some(1 + 2 + 16)
334 }
335 fn raw_min_bytes() -> Option<usize> {
336 Some(1 + 2 + 4)
337 }
338}
339
340impl RawFixedBytes for SignedEndpoint {
341 fn raw_max_bytes() -> Option<usize> {
343 Some(1 + 2 + 16)
344 }
345 fn raw_min_bytes() -> Option<usize> {
346 Some(1 + 2 + 4)
347 }
348}
349
350impl Endpoint {
351 fn flags(&self) -> u8 {
352 let mut flags = 0u8;
353 flags |= match self.protocol {
354 Protocol::Tcp => ENDPOINT_PROTOCOL_TCP,
355 Protocol::Unk => ENDPOINT_PROTOCOL_UNK,
356 Protocol::Udp => ENDPOINT_PROTOCOL_UDP,
357 };
358 flags |= match self.is_static_wan() {
359 true => ENDPOINT_FLAG_STATIC_WAN,
360 false => 0,
361 };
362 flags |= match self.is_sys_default() {
363 true => ENDPOINT_FLAG_DEFAULT,
364 false => 0,
365 };
366 flags |= match self.addr {
367 SocketAddr::V4(_) => ENDPOINT_IP_VERSION_4,
368 SocketAddr::V6(_) => ENDPOINT_IP_VERSION_6,
369 };
370 flags
371 }
372
373 fn raw_encode_no_flags<'a>(&self, buf: &'a mut [u8]) -> Result<&'a mut [u8], BuckyError> {
374 buf[0..2].copy_from_slice(&self.addr.port().to_le_bytes()[..]);
375 let buf = &mut buf[2..];
376
377 match self.addr {
378 SocketAddr::V4(ref sock_addr) => {
379 if buf.len() < 4 {
380 let msg = format!(
381 "not enough buffer for encode SocketAddrV4, except={}, got={}",
382 4,
383 buf.len()
384 );
385 error!("{}", msg);
386
387 Err(BuckyError::new(BuckyErrorCode::OutOfLimit, msg))
388 } else {
389 unsafe {
390 std::ptr::copy(
391 sock_addr.ip().octets().as_ptr() as *const u8,
392 buf.as_mut_ptr(),
393 4,
394 );
395 }
396 Ok(&mut buf[4..])
397 }
398 }
399 SocketAddr::V6(ref sock_addr) => {
400 if buf.len() < 16 {
401 let msg = format!(
402 "not enough buffer for encode SocketAddrV6, except={}, got={}",
403 16,
404 buf.len()
405 );
406 error!("{}", msg);
407
408 Err(BuckyError::new(BuckyErrorCode::OutOfLimit, msg))
409 } else {
410 buf[..16].copy_from_slice(&sock_addr.ip().octets());
411 Ok(&mut buf[16..])
412 }
413 }
414 }
415 }
416
417 fn raw_decode_no_flags<'de>(
418 flags: u8,
419 buf: &'de [u8],
420 ) -> Result<(Self, &'de [u8]), BuckyError> {
421 let protocol = match flags & ENDPOINT_PROTOCOL_TCP {
422 0 => match flags & ENDPOINT_PROTOCOL_UDP {
423 0 => Protocol::Unk,
424 _ => Protocol::Udp,
425 },
426 _ => Protocol::Tcp,
427 };
428
429 let area = if flags & ENDPOINT_FLAG_STATIC_WAN != 0 {
430 EndpointArea::Wan
431 } else if flags & ENDPOINT_FLAG_DEFAULT != 0 {
432 EndpointArea::Default
433 } else {
434 EndpointArea::Lan
435 };
436
437
438 let port = {
439 let mut b = [0u8; 2];
440 b.copy_from_slice(&buf[0..2]);
441 u16::from_le_bytes(b)
442 };
443 let buf = &buf[2..];
444
445 let (addr, buf) = {
446 if flags & ENDPOINT_IP_VERSION_6 != 0 {
447 if buf.len() < 16 {
448 let msg = format!(
449 "not enough buffer for decode EndPoint6, except={}, got={}",
450 16,
451 buf.len()
452 );
453 error!("{}", msg);
454
455 Err(BuckyError::new(BuckyErrorCode::OutOfLimit, msg))
456 } else {
457 let mut s: [u8; 16] = [0; 16];
458 s.copy_from_slice(&buf[..16]);
459 let addr = SocketAddr::V6(SocketAddrV6::new(Ipv6Addr::from(s), port, 0, 0));
461 Ok((addr, &buf[16..]))
462 }
463 } else {
464 let addr = SocketAddr::V4(SocketAddrV4::new(
465 Ipv4Addr::new(buf[0], buf[1], buf[2], buf[3]),
466 port,
467 ));
468 Ok((addr, &buf[4..]))
469 }
470 }?;
471
472 let ep = Endpoint {
473 area,
474 protocol,
475 addr,
476 };
477 Ok((ep, buf))
478 }
479}
480
481impl RawEncode for Endpoint {
482 fn raw_measure(&self, _purpose: &Option<RawEncodePurpose>) -> Result<usize, BuckyError> {
483 match self.addr {
484 SocketAddr::V4(_) => Ok(1 + 2 + 4),
485 SocketAddr::V6(_) => Ok(1 + 2 + 16),
486 }
487 }
488
489 fn raw_encode<'a>(
490 &self,
491 buf: &'a mut [u8],
492 _purpose: &Option<RawEncodePurpose>,
493 ) -> Result<&'a mut [u8], BuckyError> {
494 let min_bytes = Self::raw_min_bytes().unwrap();
495 if buf.len() < min_bytes {
496 let msg = format!(
497 "not enough buffer for encode Endpoint, min bytes={}, got={}",
498 min_bytes,
499 buf.len()
500 );
501 error!("{}", msg);
502
503 return Err(BuckyError::new(BuckyErrorCode::OutOfLimit, msg));
504 }
505
506 buf[0] = self.flags();
507 self.raw_encode_no_flags(&mut buf[1..])
508 }
509}
510
511impl<'de> RawDecode<'de> for Endpoint {
512 fn raw_decode(buf: &'de [u8]) -> Result<(Self, &'de [u8]), BuckyError> {
513 let min_bytes = Self::raw_min_bytes().unwrap();
514 if buf.len() < min_bytes {
515 let msg = format!(
516 "not enough buffer for decode Endpoint, min bytes={}, got={}",
517 min_bytes,
518 buf.len()
519 );
520 error!("{}", msg);
521
522 return Err(BuckyError::new(BuckyErrorCode::OutOfLimit, msg));
523 }
524 let flags = buf[0];
525 Self::raw_decode_no_flags(flags, &buf[1..])
526 }
527}
528
529impl RawEncode for SignedEndpoint {
530 fn raw_measure(&self, purpose: &Option<RawEncodePurpose>) -> Result<usize, BuckyError> {
531 self.0.raw_measure(purpose)
532 }
533
534 fn raw_encode<'a>(
535 &self,
536 buf: &'a mut [u8],
537 purpose: &Option<RawEncodePurpose>,
538 ) -> Result<&'a mut [u8], BuckyError> {
539 let bytes = self.raw_measure(purpose)?;
540 if buf.len() < bytes {
541 let msg = format!(
542 "not enough buffer for encode SignedEndpoint, except={}, got={}",
543 bytes,
544 buf.len()
545 );
546 error!("{}", msg);
547
548 return Err(BuckyError::new(BuckyErrorCode::OutOfLimit, msg));
549 }
550
551 buf[0] = self.0.flags() | ENDPOINT_FLAG_SIGNED;
552 self.0.raw_encode_no_flags(&mut buf[1..])
553 }
554}
555
556impl<'de> RawDecode<'de> for SignedEndpoint {
557 fn raw_decode(buf: &'de [u8]) -> Result<(Self, &'de [u8]), BuckyError> {
558 let min_bytes = Self::raw_min_bytes().unwrap();
559 if buf.len() < min_bytes {
560 let msg = format!(
561 "not enough buffer for decode SignedEndpoint, min bytes={}, got={}",
562 min_bytes,
563 buf.len()
564 );
565 error!("{}", msg);
566
567 return Err(BuckyError::new(BuckyErrorCode::OutOfLimit, msg));
568 }
569 let flags = buf[0];
570 if flags & ENDPOINT_FLAG_SIGNED == 0 {
571 return Err(BuckyError::new(
572 BuckyErrorCode::InvalidParam,
573 "without sign flag",
574 ));
575 }
576 let (ep, buf) = Endpoint::raw_decode_no_flags(flags, &buf[1..])?;
577 Ok((SignedEndpoint(ep), buf))
578 }
579}
580
#[cfg(test)]
mod test {
    use super::EndpointArea;
    use crate::*;
    use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6};

    /// Binary round trips for IPv4 and IPv6 endpoints, including encoded sizes.
    #[test]
    fn test_codec() {
        let ep = Endpoint::default();
        let v = ep.to_vec().unwrap();
        let ep2 = Endpoint::clone_from_slice(&v).unwrap();
        assert_eq!(ep, ep2);

        let ep: Endpoint = (
            Protocol::Tcp,
            SocketAddr::from(SocketAddrV4::new(Ipv4Addr::new(127, 11, 22, 33), 4)),
        )
            .into();
        let v = ep.to_vec().unwrap();
        assert_eq!(v.len(), 1 + 2 + 4);
        let ep2 = Endpoint::clone_from_slice(&v).unwrap();
        assert_eq!(ep, ep2);

        // IPv6 round trip. Flow info and scope id are not part of the wire format, so use 0
        // to keep `==` (which compares the full SocketAddr) meaningful.
        let mut ep: Endpoint = (
            Protocol::Udp,
            SocketAddr::from(SocketAddrV6::new(
                Ipv6Addr::new(1, 2, 3, 4, 5, 6, 7, 8),
                9,
                0,
                0,
            )),
        )
            .into();
        ep.set_area(EndpointArea::Wan);
        let v = ep.to_vec().unwrap();
        assert_eq!(v.len(), 1 + 2 + 16);
        let ep2 = Endpoint::clone_from_slice(&v).unwrap();
        assert_eq!(ep, ep2);
        assert!(ep2.is_static_wan());
    }

    /// String form and Display -> FromStr round trips.
    #[test]
    fn endpoint() {
        let ep: Endpoint = (
            Protocol::Tcp,
            SocketAddr::from(SocketAddrV4::new(Ipv4Addr::new(127, 1, 2, 3), 4)),
        )
            .into();
        assert_eq!(ep.to_string(), "L4tcp127.1.2.3:4");
        let parsed: Endpoint = ep.to_string().parse().unwrap();
        assert_eq!(parsed, ep);

        let ep: Endpoint = (
            Protocol::Tcp,
            SocketAddr::from(SocketAddrV6::new(
                Ipv6Addr::new(1, 2, 3, 4, 5, 6, 7, 8),
                9,
                10,
                11,
            )),
        )
            .into();
        assert!(ep.to_string().starts_with("L6tcp["));
        println!("{}", ep);

        let ep: Endpoint = (
            Protocol::Udp,
            SocketAddr::from(SocketAddrV6::new(
                Ipv6Addr::new(1, 2, 3, 4, 5, 6, 7, 8),
                9,
                0,
                0,
            )),
        )
            .into();
        assert_eq!(ep.to_string(), "L6udp[1:2:3:4:5:6:7:8]:9");
        let parsed: Endpoint = ep.to_string().parse().unwrap();
        assert_eq!(parsed, ep);
    }
}