bucky_objects/
endpoint.rs

1pub use std::net::{IpAddr, SocketAddr};
2use std::net::{Ipv4Addr, Ipv6Addr, SocketAddrV4, SocketAddrV6, ToSocketAddrs};
3use std::str::FromStr;
4
5use crate::*;
6use std::cmp::Ordering;
7
/// Transport protocol carried by an `Endpoint`.
///
/// The explicit discriminants are kept as-is; the derived ordering follows
/// declaration order (`Unk < Tcp < Udp`).
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord)]
pub enum Protocol {
    /// Protocol not known / not specified.
    Unk = 0,
    /// TCP.
    Tcp = 1,
    /// UDP.
    Udp = 2,
}
14
/// Network area an `Endpoint` belongs to.
///
/// In the textual form each area is a single leading letter: `L`, `D`, `W`, `M`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum EndpointArea {
    /// Local (LAN) address, written as `L`.
    Lan,
    /// System default address, written as `D`; sockets bind it on the zero address.
    Default,
    /// Static WAN address, written as `W`.
    Wan,
    /// Mapped WAN address, written as `M`; also counts as static WAN.
    Mapped,
}
22
23#[derive(Copy, Clone, Eq)]
24pub struct Endpoint {
25    area: EndpointArea,
26    protocol: Protocol,
27    addr: SocketAddr,
28}
29
30impl Endpoint {
31    pub fn protocol(&self) -> Protocol {
32        self.protocol
33    }
34    pub fn set_protocol(&mut self, p: Protocol) {
35        self.protocol = p
36    }
37
38    pub fn addr(&self) -> &SocketAddr {
39        &self.addr
40    }
41
42    pub fn mut_addr(&mut self) -> &mut SocketAddr {
43        &mut self.addr
44    }
45
46    pub fn is_same_ip_version(&self, other: &Endpoint) -> bool {
47        self.addr.is_ipv4() == other.addr.is_ipv4()
48    }
49
50    pub fn is_same_ip_addr(&self, other: &Endpoint) -> bool {
51        let mut self_ip = self.addr;
52        self_ip.set_port(0);
53        let mut other_ip = other.addr;
54        other_ip.set_port(0);
55        self_ip == other_ip
56    }
57
58    pub fn default_of(ep: &Endpoint) -> Self {
59        match ep.protocol {
60            Protocol::Tcp => Self::default_tcp(ep),
61            Protocol::Udp => Self::default_udp(ep),
62            _ => Self {
63                area: EndpointArea::Lan,
64                protocol: Protocol::Unk,
65                addr: match ep.addr().is_ipv4() {
66                    true => SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(0, 0, 0, 0), 0)),
67                    false => SocketAddr::V6(SocketAddrV6::new(
68                        Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 0),
69                        0,
70                        0,
71                        0,
72                    )),
73                },
74            },
75        }
76    }
77
78    pub fn default_tcp(ep: &Endpoint) -> Self {
79        Self {
80            area: EndpointArea::Lan,
81            protocol: Protocol::Tcp,
82            addr: match ep.addr().is_ipv4() {
83                true => SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(0, 0, 0, 0), 0)),
84                false => SocketAddr::V6(SocketAddrV6::new(
85                    Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 0),
86                    0,
87                    0,
88                    0,
89                )),
90            },
91        }
92    }
93
94    pub fn default_udp(ep: &Endpoint) -> Self {
95        Self {
96            area: EndpointArea::Lan,
97            protocol: Protocol::Udp,
98            addr: match ep.addr().is_ipv4() {
99                true => SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(0, 0, 0, 0), 0)),
100                false => SocketAddr::V6(SocketAddrV6::new(
101                    Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 0),
102                    0,
103                    0,
104                    0,
105                )),
106            },
107        }
108    }
109
110    pub fn is_udp(&self) -> bool {
111        self.protocol == Protocol::Udp
112    }
113    pub fn is_tcp(&self) -> bool {
114        self.protocol == Protocol::Tcp
115    }
116    pub fn is_sys_default(&self) -> bool {
117        self.area == EndpointArea::Default
118    }
119    pub fn is_static_wan(&self) -> bool {
120        self.area == EndpointArea::Wan
121            || self.area == EndpointArea::Mapped
122    }
123
124    pub fn is_mapped_wan(&self) -> bool {
125        self.area == EndpointArea::Mapped
126    }
127
128    pub fn set_area(&mut self, area: EndpointArea) {
129        self.area = area;
130    }
131}
132
133impl Default for Endpoint {
134    fn default() -> Self {
135        Self {
136            area: EndpointArea::Lan,
137            protocol: Protocol::Unk,
138            addr: SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(0, 0, 0, 0), 0)),
139        }
140    }
141}
142
143impl From<(Protocol, SocketAddr)> for Endpoint {
144    fn from(ps: (Protocol, SocketAddr)) -> Self {
145        Self {
146            area: EndpointArea::Lan,
147            protocol: ps.0,
148            addr: ps.1,
149        }
150    }
151}
152
153impl From<(Protocol, IpAddr, u16)> for Endpoint {
154    fn from(piu: (Protocol, IpAddr, u16)) -> Self {
155        Self {
156            area: EndpointArea::Lan,
157            protocol: piu.0,
158            addr: SocketAddr::new(piu.1, piu.2),
159        }
160    }
161}
162
163impl ToSocketAddrs for Endpoint {
164    type Iter = <SocketAddr as ToSocketAddrs>::Iter;
165    fn to_socket_addrs(&self) -> std::io::Result<Self::Iter> {
166        self.addr.to_socket_addrs()
167    }
168}
169
170impl PartialEq for Endpoint {
171    fn eq(&self, other: &Endpoint) -> bool {
172        self.protocol == other.protocol && self.addr == other.addr
173    }
174}
175
176impl PartialOrd for Endpoint {
177    fn partial_cmp(&self, other: &Endpoint) -> Option<std::cmp::Ordering> {
178        use std::cmp::Ordering::*;
179        match self.protocol.partial_cmp(&other.protocol).unwrap() {
180            Equal => match self.addr.ip().partial_cmp(&other.addr().ip()) {
181                None => self.addr.port().partial_cmp(&other.addr.port()),
182                Some(ord) => match ord {
183                    Greater => Some(Greater),
184                    Less => Some(Less),
185                    Equal => self.addr.port().partial_cmp(&other.addr.port()),
186                },
187            },
188            Greater => Some(Greater),
189            Less => Some(Less),
190        }
191    }
192}
193
194impl Ord for Endpoint {
195    fn cmp(&self, other: &Self) -> Ordering {
196        self.partial_cmp(other).unwrap()
197    }
198}
199
200impl std::fmt::Debug for Endpoint {
201    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
202        write!(f, "{}", self)
203    }
204}
205
206
207impl std::fmt::Display for Endpoint {
208    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
209        let mut result = String::new();
210
211        result += match self.area {
212            EndpointArea::Lan => "L", // LOCAL
213            EndpointArea::Default => "D", // DEFAULT,
214            EndpointArea::Wan =>  "W", // WAN,
215            EndpointArea::Mapped => "M" // MAPPED WAN,
216        };
217
218        result += match self.addr {
219            SocketAddr::V4(_) => "4",
220            SocketAddr::V6(_) => "6",
221        };
222
223        result += match self.protocol {
224            Protocol::Unk => "unk",
225            Protocol::Tcp => "tcp",
226            Protocol::Udp => "udp",
227        };
228
229        result += self.addr.to_string().as_str();
230
231        write!(f, "{}", &result)
232    }
233}
234
235impl FromStr for Endpoint {
236    type Err = BuckyError;
237    fn from_str(s: &str) -> Result<Self, Self::Err> {
238        let area = {
239            match &s[0..1] {
240                "W" => Ok(EndpointArea::Wan),
241                "M" => Ok(EndpointArea::Mapped),
242                "L" => Ok(EndpointArea::Lan),
243                "D" => Ok(EndpointArea::Default),
244                _ => Err(BuckyError::new(
245                    BuckyErrorCode::InvalidInput,
246                    "invalid endpoint string",
247                )),
248            }
249        }?;
250        let version_str = &s[1..2];
251
252        let protocol = {
253            match &s[2..5] {
254                "tcp" => Ok(Protocol::Tcp),
255                "udp" => Ok(Protocol::Udp),
256                _ => Err(BuckyError::new(
257                    BuckyErrorCode::InvalidInput,
258                    "invalid endpoint string",
259                )),
260            }
261        }?;
262
263        let addr = SocketAddr::from_str(&s[5..]).map_err(|_| {
264            BuckyError::new(BuckyErrorCode::InvalidInput, "invalid endpoint string")
265        })?;
266        if !(addr.is_ipv4() && version_str.eq("4") || addr.is_ipv6() && version_str.eq("6")) {
267            return Err(BuckyError::new(
268                BuckyErrorCode::InvalidInput,
269                "invalid endpoint string",
270            ));
271        }
272        Ok(Endpoint {
273            area,
274            protocol,
275            addr,
276        })
277    }
278}
279
280pub fn endpoints_to_string(eps: &[Endpoint]) -> String {
281    let mut s = "[".to_string();
282    if eps.len() > 0 {
283        s += eps[0].to_string().as_str();
284    }
285
286    if eps.len() > 1 {
287        for i in 1..eps.len() {
288            s += ",";
289            s += eps[i].to_string().as_str();
290        }
291    }
292    s += "]";
293    s
294}
295
// Bit layout of the flags byte that prefixes every binary-encoded endpoint.

/// Marks the system default address: the socket is bound to the zero (unspecified)
/// address.
const ENDPOINT_FLAG_DEFAULT: u8 = 0b0000_0001;

// Protocol bits; "unknown" is the absence of both protocol bits.
const ENDPOINT_PROTOCOL_UNK: u8 = 0b0000_0000;
const ENDPOINT_PROTOCOL_TCP: u8 = 0b0000_0010;
const ENDPOINT_PROTOCOL_UDP: u8 = 0b0000_0100;

// IP family bits.
const ENDPOINT_IP_VERSION_4: u8 = 0b0000_1000;
const ENDPOINT_IP_VERSION_6: u8 = 0b0001_0000;
/// Static WAN marker (bit 5 is unused).
const ENDPOINT_FLAG_STATIC_WAN: u8 = 0b0100_0000;
/// Marks the encoding of a `SignedEndpoint`.
const ENDPOINT_FLAG_SIGNED: u8 = 0b1000_0000;
307
308#[derive(Clone)]
309pub struct SignedEndpoint(Endpoint);
310
311impl From<Endpoint> for SignedEndpoint {
312    fn from(ep: Endpoint) -> Self {
313        Self(ep)
314    }
315}
316
317impl Into<Endpoint> for SignedEndpoint {
318    fn into(self) -> Endpoint {
319        self.0
320    }
321}
322
323impl AsRef<Endpoint> for SignedEndpoint {
324    fn as_ref(&self) -> &Endpoint {
325        &self.0
326    }
327}
328
329impl RawFixedBytes for Endpoint {
330    // TOFIX: C BDT union addr and addrV6 should not memcpy directly
331    fn raw_max_bytes() -> Option<usize> {
332        Some(1 + 2 + 16)
333    }
334    fn raw_min_bytes() -> Option<usize> {
335        Some(1 + 2 + 4)
336    }
337}
338
339impl RawFixedBytes for SignedEndpoint {
340    // TOFIX: C BDT union addr and addrV6 should not memcpy directly
341    fn raw_max_bytes() -> Option<usize> {
342        Some(1 + 2 + 16)
343    }
344    fn raw_min_bytes() -> Option<usize> {
345        Some(1 + 2 + 4)
346    }
347}
348
349impl Endpoint {
350    fn flags(&self) -> u8 {
351        let mut flags = 0u8;
352        flags |= match self.protocol {
353            Protocol::Tcp => ENDPOINT_PROTOCOL_TCP,
354            Protocol::Unk => ENDPOINT_PROTOCOL_UNK,
355            Protocol::Udp => ENDPOINT_PROTOCOL_UDP,
356        };
357        flags |= match self.is_static_wan() {
358            true => ENDPOINT_FLAG_STATIC_WAN,
359            false => 0,
360        };
361        flags |= match self.is_sys_default() {
362            true => ENDPOINT_FLAG_DEFAULT,
363            false => 0,
364        };
365        flags |= match self.addr {
366            SocketAddr::V4(_) => ENDPOINT_IP_VERSION_4,
367            SocketAddr::V6(_) => ENDPOINT_IP_VERSION_6,
368        };
369        flags
370    }
371
372    fn raw_encode_no_flags<'a>(&self, buf: &'a mut [u8]) -> Result<&'a mut [u8], BuckyError> {
373        buf[0..2].copy_from_slice(&self.addr.port().to_le_bytes()[..]);
374        let buf = &mut buf[2..];
375
376        match self.addr {
377            SocketAddr::V4(ref sock_addr) => {
378                if buf.len() < 4 {
379                    let msg = format!(
380                        "not enough buffer for encode SocketAddrV4, except={}, got={}",
381                        4,
382                        buf.len()
383                    );
384                    error!("{}", msg);
385
386                    Err(BuckyError::new(BuckyErrorCode::OutOfLimit, msg))
387                } else {
388                    unsafe {
389                        std::ptr::copy(
390                            sock_addr.ip().octets().as_ptr() as *const u8,
391                            buf.as_mut_ptr(),
392                            4,
393                        );
394                    }
395                    Ok(&mut buf[4..])
396                }
397            }
398            SocketAddr::V6(ref sock_addr) => {
399                if buf.len() < 16 {
400                    let msg = format!(
401                        "not enough buffer for encode SocketAddrV6, except={}, got={}",
402                        16,
403                        buf.len()
404                    );
405                    error!("{}", msg);
406
407                    Err(BuckyError::new(BuckyErrorCode::OutOfLimit, msg))
408                } else {
409                    buf[..16].copy_from_slice(&sock_addr.ip().octets());
410                    Ok(&mut buf[16..])
411                }
412            }
413        }
414    }
415
416    fn raw_decode_no_flags<'de>(
417        flags: u8,
418        buf: &'de [u8],
419    ) -> Result<(Self, &'de [u8]), BuckyError> {
420        let protocol = match flags & ENDPOINT_PROTOCOL_TCP {
421            0 => match flags & ENDPOINT_PROTOCOL_UDP {
422                0 => Protocol::Unk,
423                _ => Protocol::Udp,
424            },
425            _ => Protocol::Tcp,
426        };
427
428        let area = if flags & ENDPOINT_FLAG_STATIC_WAN != 0 {
429            EndpointArea::Wan
430        } else if flags & ENDPOINT_FLAG_DEFAULT != 0 {
431            EndpointArea::Default
432        } else {
433            EndpointArea::Lan
434        };
435
436
437        let port = {
438            let mut b = [0u8; 2];
439            b.copy_from_slice(&buf[0..2]);
440            u16::from_le_bytes(b)
441        };
442        let buf = &buf[2..];
443
444        let (addr, buf) = {
445            if flags & ENDPOINT_IP_VERSION_6 != 0 {
446                if buf.len() < 16 {
447                    let msg = format!(
448                        "not enough buffer for decode EndPoint6, except={}, got={}",
449                        16,
450                        buf.len()
451                    );
452                    error!("{}", msg);
453
454                    Err(BuckyError::new(BuckyErrorCode::OutOfLimit, msg))
455                } else {
456                    let mut s: [u8; 16] = [0; 16];
457                    s.copy_from_slice(&buf[..16]);
458                    // TOFIX: flow and scope
459                    let addr = SocketAddr::V6(SocketAddrV6::new(Ipv6Addr::from(s), port, 0, 0));
460                    Ok((addr, &buf[16..]))
461                }
462            } else {
463                let addr = SocketAddr::V4(SocketAddrV4::new(
464                    Ipv4Addr::new(buf[0], buf[1], buf[2], buf[3]),
465                    port,
466                ));
467                Ok((addr, &buf[4..]))
468            }
469        }?;
470
471        let ep = Endpoint {
472            area,
473            protocol,
474            addr,
475        };
476        Ok((ep, buf))
477    }
478}
479
480impl RawEncode for Endpoint {
481    fn raw_measure(&self, _purpose: &Option<RawEncodePurpose>) -> Result<usize, BuckyError> {
482        match self.addr {
483            SocketAddr::V4(_) => Ok(1 + 2 + 4),
484            SocketAddr::V6(_) => Ok(1 + 2 + 16),
485        }
486    }
487
488    fn raw_encode<'a>(
489        &self,
490        buf: &'a mut [u8],
491        _purpose: &Option<RawEncodePurpose>,
492    ) -> Result<&'a mut [u8], BuckyError> {
493        let min_bytes = Self::raw_min_bytes().unwrap();
494        if buf.len() < min_bytes {
495            let msg = format!(
496                "not enough buffer for encode Endpoint, min bytes={}, got={}",
497                min_bytes,
498                buf.len()
499            );
500            error!("{}", msg);
501
502            return Err(BuckyError::new(BuckyErrorCode::OutOfLimit, msg));
503        }
504
505        buf[0] = self.flags();
506        self.raw_encode_no_flags(&mut buf[1..])
507    }
508}
509
510impl<'de> RawDecode<'de> for Endpoint {
511    fn raw_decode(buf: &'de [u8]) -> Result<(Self, &'de [u8]), BuckyError> {
512        let min_bytes = Self::raw_min_bytes().unwrap();
513        if buf.len() < min_bytes {
514            let msg = format!(
515                "not enough buffer for decode Endpoint, min bytes={}, got={}",
516                min_bytes,
517                buf.len()
518            );
519            error!("{}", msg);
520
521            return Err(BuckyError::new(BuckyErrorCode::OutOfLimit, msg));
522        }
523        let flags = buf[0];
524        Self::raw_decode_no_flags(flags, &buf[1..])
525    }
526}
527
528impl RawEncode for SignedEndpoint {
529    fn raw_measure(&self, purpose: &Option<RawEncodePurpose>) -> Result<usize, BuckyError> {
530        self.0.raw_measure(purpose)
531    }
532
533    fn raw_encode<'a>(
534        &self,
535        buf: &'a mut [u8],
536        purpose: &Option<RawEncodePurpose>,
537    ) -> Result<&'a mut [u8], BuckyError> {
538        let bytes = self.raw_measure(purpose)?;
539        if buf.len() < bytes {
540            let msg = format!(
541                "not enough buffer for encode SignedEndpoint, except={}, got={}",
542                bytes,
543                buf.len()
544            );
545            error!("{}", msg);
546
547            return Err(BuckyError::new(BuckyErrorCode::OutOfLimit, msg));
548        }
549
550        buf[0] = self.0.flags() | ENDPOINT_FLAG_SIGNED;
551        self.0.raw_encode_no_flags(&mut buf[1..])
552    }
553}
554
555impl<'de> RawDecode<'de> for SignedEndpoint {
556    fn raw_decode(buf: &'de [u8]) -> Result<(Self, &'de [u8]), BuckyError> {
557        let min_bytes = Self::raw_min_bytes().unwrap();
558        if buf.len() < min_bytes {
559            let msg = format!(
560                "not enough buffer for decode SignedEndpoint, min bytes={}, got={}",
561                min_bytes,
562                buf.len()
563            );
564            error!("{}", msg);
565
566            return Err(BuckyError::new(BuckyErrorCode::OutOfLimit, msg));
567        }
568        let flags = buf[0];
569        if flags & ENDPOINT_FLAG_SIGNED == 0 {
570            return Err(BuckyError::new(
571                BuckyErrorCode::InvalidParam,
572                "without sign flag",
573            ));
574        }
575        let (ep, buf) = Endpoint::raw_decode_no_flags(flags, &buf[1..])?;
576        Ok((SignedEndpoint(ep), buf))
577    }
578}
579
#[cfg(test)]
mod test {
    use crate::*;
    // The standard library provides these types directly; no async runtime is needed.
    use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6};

    /// Encodes `ep` to bytes and decodes it back.
    fn codec_round_trip(ep: &Endpoint) -> Endpoint {
        let v = ep.to_vec().unwrap();
        Endpoint::clone_from_slice(&v).unwrap()
    }

    #[test]
    fn test_codec() {
        let ep = Endpoint::default();
        assert_eq!(ep, codec_round_trip(&ep));

        let ep: Endpoint = (
            Protocol::Tcp,
            SocketAddr::from(SocketAddrV4::new(Ipv4Addr::new(127, 11, 22, 33), 4)),
        )
            .into();
        assert_eq!(ep.to_vec().unwrap().len(), 1 + 2 + 4);
        assert_eq!(ep, codec_round_trip(&ep));

        // IPv6 with zero flow info / scope id survives the round trip unchanged.
        let ep: Endpoint = (
            Protocol::Udp,
            SocketAddr::from(SocketAddrV6::new(
                Ipv6Addr::new(1, 2, 3, 4, 5, 6, 7, 8),
                9,
                0,
                0,
            )),
        )
            .into();
        assert_eq!(ep.to_vec().unwrap().len(), 1 + 2 + 16);
        assert_eq!(ep, codec_round_trip(&ep));

        // The area is carried in the flags byte.
        let mut wan = ep;
        wan.set_area(EndpointArea::Wan);
        assert!(codec_round_trip(&wan).is_static_wan());
        let mut sys_default = ep;
        sys_default.set_area(EndpointArea::Default);
        assert!(codec_round_trip(&sys_default).is_sys_default());
    }

    #[test]
    fn test_signed_codec() {
        let ep: Endpoint = (
            Protocol::Tcp,
            SocketAddr::from(SocketAddrV4::new(Ipv4Addr::new(10, 0, 0, 1), 80)),
        )
            .into();
        let signed = SignedEndpoint::from(ep);
        let v = signed.to_vec().unwrap();
        let decoded = SignedEndpoint::clone_from_slice(&v).unwrap();
        let inner: &Endpoint = decoded.as_ref();
        assert_eq!(inner, &ep);

        // Plain endpoint bytes lack the signed flag and must be rejected.
        let plain = ep.to_vec().unwrap();
        assert!(SignedEndpoint::clone_from_slice(&plain).is_err());
    }

    #[test]
    fn endpoint() {
        let ep: Endpoint = (
            Protocol::Tcp,
            SocketAddr::from(SocketAddrV4::new(Ipv4Addr::new(127, 1, 2, 3), 4)),
        )
            .into();
        assert_eq!(ep.to_string(), "L4tcp127.1.2.3:4");
        assert_eq!(ep.to_string().parse::<Endpoint>().unwrap(), ep);

        let ep: Endpoint = (
            Protocol::Tcp,
            SocketAddr::from(SocketAddrV6::new(
                Ipv6Addr::new(1, 2, 3, 4, 5, 6, 7, 8),
                9,
                0,
                0,
            )),
        )
            .into();
        assert_eq!(ep.to_string(), "L6tcp[1:2:3:4:5:6:7:8]:9");
        assert_eq!(ep.to_string().parse::<Endpoint>().unwrap(), ep);
    }
}