// cyfs_base/base/endpoint.rs

1pub use std::net::{IpAddr, SocketAddr};
2use std::net::{Ipv4Addr, Ipv6Addr, SocketAddrV4, SocketAddrV6, ToSocketAddrs};
3use std::str::FromStr;
4
5use crate::codec::{RawDecode, RawEncode, RawEncodePurpose, RawFixedBytes};
6use crate::*;
7use std::cmp::Ordering;
8
/// Transport protocol carried by an [`Endpoint`].
///
/// Discriminants are explicit; the derived ordering is by discriminant,
/// i.e. `Unk < Tcp < Udp`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord)]
pub enum Protocol {
    /// Unknown / unspecified protocol (rendered as `unk`).
    Unk = 0,
    /// TCP (rendered as `tcp`).
    Tcp = 1,
    /// UDP (rendered as `udp`).
    Udp = 2,
}
15
/// Classification attached to an [`Endpoint`].
///
/// In the endpoint string form each variant is one leading letter:
/// `L` (Lan), `D` (Default), `W` (Wan), `M` (Mapped). Both `Wan` and `Mapped`
/// count as "static WAN"; `Default` is what `is_sys_default` checks for.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum EndpointArea {
    Lan,
    Default,
    Wan,
    Mapped,
}
23
24#[derive(Copy, Clone, Eq)]
25pub struct Endpoint {
26    area: EndpointArea,
27    protocol: Protocol,
28    addr: SocketAddr,
29}
30
31impl Endpoint {
32    pub fn protocol(&self) -> Protocol {
33        self.protocol
34    }
35    pub fn set_protocol(&mut self, p: Protocol) {
36        self.protocol = p
37    }
38
39    pub fn addr(&self) -> &SocketAddr {
40        &self.addr
41    }
42
43    pub fn mut_addr(&mut self) -> &mut SocketAddr {
44        &mut self.addr
45    }
46
47    pub fn is_same_ip_version(&self, other: &Endpoint) -> bool {
48        self.addr.is_ipv4() == other.addr.is_ipv4()
49    }
50
51    pub fn is_same_ip_addr(&self, other: &Endpoint) -> bool {
52        let mut self_ip = self.addr;
53        self_ip.set_port(0);
54        let mut other_ip = other.addr;
55        other_ip.set_port(0);
56        self_ip == other_ip
57    }
58
59    pub fn default_of(ep: &Endpoint) -> Self {
60        match ep.protocol {
61            Protocol::Tcp => Self::default_tcp(ep),
62            Protocol::Udp => Self::default_udp(ep),
63            _ => Self {
64                area: EndpointArea::Lan,
65                protocol: Protocol::Unk,
66                addr: match ep.addr().is_ipv4() {
67                    true => SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(0, 0, 0, 0), 0)),
68                    false => SocketAddr::V6(SocketAddrV6::new(
69                        Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 0),
70                        0,
71                        0,
72                        0,
73                    )),
74                },
75            },
76        }
77    }
78
79    pub fn default_tcp(ep: &Endpoint) -> Self {
80        Self {
81            area: EndpointArea::Lan,
82            protocol: Protocol::Tcp,
83            addr: match ep.addr().is_ipv4() {
84                true => SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(0, 0, 0, 0), 0)),
85                false => SocketAddr::V6(SocketAddrV6::new(
86                    Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 0),
87                    0,
88                    0,
89                    0,
90                )),
91            },
92        }
93    }
94
95    pub fn default_udp(ep: &Endpoint) -> Self {
96        Self {
97            area: EndpointArea::Lan,
98            protocol: Protocol::Udp,
99            addr: match ep.addr().is_ipv4() {
100                true => SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(0, 0, 0, 0), 0)),
101                false => SocketAddr::V6(SocketAddrV6::new(
102                    Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 0),
103                    0,
104                    0,
105                    0,
106                )),
107            },
108        }
109    }
110
111    pub fn is_udp(&self) -> bool {
112        self.protocol == Protocol::Udp
113    }
114    pub fn is_tcp(&self) -> bool {
115        self.protocol == Protocol::Tcp
116    }
117    pub fn is_sys_default(&self) -> bool {
118        self.area == EndpointArea::Default
119    }
120    pub fn is_static_wan(&self) -> bool {
121        self.area == EndpointArea::Wan
122            || self.area == EndpointArea::Mapped
123    }
124
125    pub fn is_mapped_wan(&self) -> bool {
126        self.area == EndpointArea::Mapped
127    }
128
129    pub fn set_area(&mut self, area: EndpointArea) {
130        self.area = area;
131    }
132}
133
134impl Default for Endpoint {
135    fn default() -> Self {
136        Self {
137            area: EndpointArea::Lan,
138            protocol: Protocol::Unk,
139            addr: SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(0, 0, 0, 0), 0)),
140        }
141    }
142}
143
144impl From<(Protocol, SocketAddr)> for Endpoint {
145    fn from(ps: (Protocol, SocketAddr)) -> Self {
146        Self {
147            area: EndpointArea::Lan,
148            protocol: ps.0,
149            addr: ps.1,
150        }
151    }
152}
153
154impl From<(Protocol, IpAddr, u16)> for Endpoint {
155    fn from(piu: (Protocol, IpAddr, u16)) -> Self {
156        Self {
157            area: EndpointArea::Lan,
158            protocol: piu.0,
159            addr: SocketAddr::new(piu.1, piu.2),
160        }
161    }
162}
163
164impl ToSocketAddrs for Endpoint {
165    type Iter = <SocketAddr as ToSocketAddrs>::Iter;
166    fn to_socket_addrs(&self) -> std::io::Result<Self::Iter> {
167        self.addr.to_socket_addrs()
168    }
169}
170
171impl PartialEq for Endpoint {
172    fn eq(&self, other: &Endpoint) -> bool {
173        self.protocol == other.protocol && self.addr == other.addr
174    }
175}
176
177impl PartialOrd for Endpoint {
178    fn partial_cmp(&self, other: &Endpoint) -> Option<std::cmp::Ordering> {
179        use std::cmp::Ordering::*;
180        match self.protocol.partial_cmp(&other.protocol).unwrap() {
181            Equal => match self.addr.ip().partial_cmp(&other.addr().ip()) {
182                None => self.addr.port().partial_cmp(&other.addr.port()),
183                Some(ord) => match ord {
184                    Greater => Some(Greater),
185                    Less => Some(Less),
186                    Equal => self.addr.port().partial_cmp(&other.addr.port()),
187                },
188            },
189            Greater => Some(Greater),
190            Less => Some(Less),
191        }
192    }
193}
194
195impl Ord for Endpoint {
196    fn cmp(&self, other: &Self) -> Ordering {
197        self.partial_cmp(other).unwrap()
198    }
199}
200
201impl std::fmt::Debug for Endpoint {
202    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
203        write!(f, "{}", self)
204    }
205}
206
207
208impl std::fmt::Display for Endpoint {
209    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
210        let mut result = String::new();
211
212        result += match self.area {
213            EndpointArea::Lan => "L", // LOCAL
214            EndpointArea::Default => "D", // DEFAULT, 
215            EndpointArea::Wan =>  "W", // WAN, 
216            EndpointArea::Mapped => "M" // MAPPED WAN, 
217        };
218
219        result += match self.addr {
220            SocketAddr::V4(_) => "4",
221            SocketAddr::V6(_) => "6",
222        };
223
224        result += match self.protocol {
225            Protocol::Unk => "unk",
226            Protocol::Tcp => "tcp",
227            Protocol::Udp => "udp",
228        };
229
230        result += self.addr.to_string().as_str();
231
232        write!(f, "{}", &result)
233    }
234}
235
236impl FromStr for Endpoint {
237    type Err = BuckyError;
238    fn from_str(s: &str) -> Result<Self, Self::Err> {
239        let area = {
240            match &s[0..1] {
241                "W" => Ok(EndpointArea::Wan),
242                "M" => Ok(EndpointArea::Mapped),
243                "L" => Ok(EndpointArea::Lan),
244                "D" => Ok(EndpointArea::Default),
245                _ => Err(BuckyError::new(
246                    BuckyErrorCode::InvalidInput,
247                    "invalid endpoint string",
248                )),
249            }
250        }?;
251        let version_str = &s[1..2];
252
253        let protocol = {
254            match &s[2..5] {
255                "tcp" => Ok(Protocol::Tcp),
256                "udp" => Ok(Protocol::Udp),
257                _ => Err(BuckyError::new(
258                    BuckyErrorCode::InvalidInput,
259                    "invalid endpoint string",
260                )),
261            }
262        }?;
263
264        let addr = SocketAddr::from_str(&s[5..]).map_err(|_| {
265            BuckyError::new(BuckyErrorCode::InvalidInput, "invalid endpoint string")
266        })?;
267        if !(addr.is_ipv4() && version_str.eq("4") || addr.is_ipv6() && version_str.eq("6")) {
268            return Err(BuckyError::new(
269                BuckyErrorCode::InvalidInput,
270                "invalid endpoint string",
271            ));
272        }
273        Ok(Endpoint {
274            area, 
275            protocol,
276            addr,
277        })
278    }
279}
280
281pub fn endpoints_to_string(eps: &[Endpoint]) -> String {
282    let mut s = "[".to_string();
283    if eps.len() > 0 {
284        s += eps[0].to_string().as_str();
285    }
286
287    if eps.len() > 1 {
288        for i in 1..eps.len() {
289            s += ",";
290            s += eps[i].to_string().as_str();
291        }
292    }
293    s += "]";
294    s
295}
296
// Bit layout of the flags byte that precedes an encoded endpoint.

// Marks the default address: sockets bind with the 0 (unspecified) address.
const ENDPOINT_FLAG_DEFAULT: u8 = 0b0000_0001;

// Protocol bits; an unknown protocol sets neither.
const ENDPOINT_PROTOCOL_UNK: u8 = 0b0000_0000;
const ENDPOINT_PROTOCOL_TCP: u8 = 0b0000_0010;
const ENDPOINT_PROTOCOL_UDP: u8 = 0b0000_0100;

// IP version bits.
const ENDPOINT_IP_VERSION_4: u8 = 0b0000_1000;
const ENDPOINT_IP_VERSION_6: u8 = 0b0001_0000;
// Bit 5 (0b0010_0000) is not assigned here.
// Static WAN address (used for both the `Wan` and `Mapped` areas).
const ENDPOINT_FLAG_STATIC_WAN: u8 = 0b0100_0000;
// Set on the wire by `SignedEndpoint` encoding and required by its decoder.
const ENDPOINT_FLAG_SIGNED: u8 = 0b1000_0000;
308
309#[derive(Clone)]
310pub struct SignedEndpoint(Endpoint);
311
312impl From<Endpoint> for SignedEndpoint {
313    fn from(ep: Endpoint) -> Self {
314        Self(ep)
315    }
316}
317
318impl Into<Endpoint> for SignedEndpoint {
319    fn into(self) -> Endpoint {
320        self.0
321    }
322}
323
324impl AsRef<Endpoint> for SignedEndpoint {
325    fn as_ref(&self) -> &Endpoint {
326        &self.0
327    }
328}
329
330impl RawFixedBytes for Endpoint {
331    // TOFIX: C BDT union addr and addrV6 should not memcpy directly
332    fn raw_max_bytes() -> Option<usize> {
333        Some(1 + 2 + 16)
334    }
335    fn raw_min_bytes() -> Option<usize> {
336        Some(1 + 2 + 4)
337    }
338}
339
340impl RawFixedBytes for SignedEndpoint {
341    // TOFIX: C BDT union addr and addrV6 should not memcpy directly
342    fn raw_max_bytes() -> Option<usize> {
343        Some(1 + 2 + 16)
344    }
345    fn raw_min_bytes() -> Option<usize> {
346        Some(1 + 2 + 4)
347    }
348}
349
350impl Endpoint {
351    fn flags(&self) -> u8 {
352        let mut flags = 0u8;
353        flags |= match self.protocol {
354            Protocol::Tcp => ENDPOINT_PROTOCOL_TCP,
355            Protocol::Unk => ENDPOINT_PROTOCOL_UNK,
356            Protocol::Udp => ENDPOINT_PROTOCOL_UDP,
357        };
358        flags |= match self.is_static_wan() {
359            true => ENDPOINT_FLAG_STATIC_WAN,
360            false => 0,
361        };
362        flags |= match self.is_sys_default() {
363            true => ENDPOINT_FLAG_DEFAULT,
364            false => 0,
365        };
366        flags |= match self.addr {
367            SocketAddr::V4(_) => ENDPOINT_IP_VERSION_4,
368            SocketAddr::V6(_) => ENDPOINT_IP_VERSION_6,
369        };
370        flags
371    }
372
373    fn raw_encode_no_flags<'a>(&self, buf: &'a mut [u8]) -> Result<&'a mut [u8], BuckyError> {
374        buf[0..2].copy_from_slice(&self.addr.port().to_le_bytes()[..]);
375        let buf = &mut buf[2..];
376
377        match self.addr {
378            SocketAddr::V4(ref sock_addr) => {
379                if buf.len() < 4 {
380                    let msg = format!(
381                        "not enough buffer for encode SocketAddrV4, except={}, got={}",
382                        4,
383                        buf.len()
384                    );
385                    error!("{}", msg);
386
387                    Err(BuckyError::new(BuckyErrorCode::OutOfLimit, msg))
388                } else {
389                    unsafe {
390                        std::ptr::copy(
391                            sock_addr.ip().octets().as_ptr() as *const u8,
392                            buf.as_mut_ptr(),
393                            4,
394                        );
395                    }
396                    Ok(&mut buf[4..])
397                }
398            }
399            SocketAddr::V6(ref sock_addr) => {
400                if buf.len() < 16 {
401                    let msg = format!(
402                        "not enough buffer for encode SocketAddrV6, except={}, got={}",
403                        16,
404                        buf.len()
405                    );
406                    error!("{}", msg);
407
408                    Err(BuckyError::new(BuckyErrorCode::OutOfLimit, msg))
409                } else {
410                    buf[..16].copy_from_slice(&sock_addr.ip().octets());
411                    Ok(&mut buf[16..])
412                }
413            }
414        }
415    }
416
417    fn raw_decode_no_flags<'de>(
418        flags: u8,
419        buf: &'de [u8],
420    ) -> Result<(Self, &'de [u8]), BuckyError> {
421        let protocol = match flags & ENDPOINT_PROTOCOL_TCP {
422            0 => match flags & ENDPOINT_PROTOCOL_UDP {
423                0 => Protocol::Unk,
424                _ => Protocol::Udp,
425            },
426            _ => Protocol::Tcp,
427        };
428
429        let area = if flags & ENDPOINT_FLAG_STATIC_WAN != 0 {
430            EndpointArea::Wan
431        } else if flags & ENDPOINT_FLAG_DEFAULT != 0 {
432            EndpointArea::Default
433        } else {
434            EndpointArea::Lan
435        };
436       
437
438        let port = {
439            let mut b = [0u8; 2];
440            b.copy_from_slice(&buf[0..2]);
441            u16::from_le_bytes(b)
442        };
443        let buf = &buf[2..];
444
445        let (addr, buf) = {
446            if flags & ENDPOINT_IP_VERSION_6 != 0 {
447                if buf.len() < 16 {
448                    let msg = format!(
449                        "not enough buffer for decode EndPoint6, except={}, got={}",
450                        16,
451                        buf.len()
452                    );
453                    error!("{}", msg);
454
455                    Err(BuckyError::new(BuckyErrorCode::OutOfLimit, msg))
456                } else {
457                    let mut s: [u8; 16] = [0; 16];
458                    s.copy_from_slice(&buf[..16]);
459                    // TOFIX: flow and scope
460                    let addr = SocketAddr::V6(SocketAddrV6::new(Ipv6Addr::from(s), port, 0, 0));
461                    Ok((addr, &buf[16..]))
462                }
463            } else {
464                let addr = SocketAddr::V4(SocketAddrV4::new(
465                    Ipv4Addr::new(buf[0], buf[1], buf[2], buf[3]),
466                    port,
467                ));
468                Ok((addr, &buf[4..]))
469            }
470        }?;
471
472        let ep = Endpoint {
473            area, 
474            protocol,
475            addr,
476        };
477        Ok((ep, buf))
478    }
479}
480
481impl RawEncode for Endpoint {
482    fn raw_measure(&self, _purpose: &Option<RawEncodePurpose>) -> Result<usize, BuckyError> {
483        match self.addr {
484            SocketAddr::V4(_) => Ok(1 + 2 + 4),
485            SocketAddr::V6(_) => Ok(1 + 2 + 16),
486        }
487    }
488
489    fn raw_encode<'a>(
490        &self,
491        buf: &'a mut [u8],
492        _purpose: &Option<RawEncodePurpose>,
493    ) -> Result<&'a mut [u8], BuckyError> {
494        let min_bytes = Self::raw_min_bytes().unwrap();
495        if buf.len() < min_bytes {
496            let msg = format!(
497                "not enough buffer for encode Endpoint, min bytes={}, got={}",
498                min_bytes,
499                buf.len()
500            );
501            error!("{}", msg);
502
503            return Err(BuckyError::new(BuckyErrorCode::OutOfLimit, msg));
504        }
505
506        buf[0] = self.flags();
507        self.raw_encode_no_flags(&mut buf[1..])
508    }
509}
510
511impl<'de> RawDecode<'de> for Endpoint {
512    fn raw_decode(buf: &'de [u8]) -> Result<(Self, &'de [u8]), BuckyError> {
513        let min_bytes = Self::raw_min_bytes().unwrap();
514        if buf.len() < min_bytes {
515            let msg = format!(
516                "not enough buffer for decode Endpoint, min bytes={}, got={}",
517                min_bytes,
518                buf.len()
519            );
520            error!("{}", msg);
521
522            return Err(BuckyError::new(BuckyErrorCode::OutOfLimit, msg));
523        }
524        let flags = buf[0];
525        Self::raw_decode_no_flags(flags, &buf[1..])
526    }
527}
528
529impl RawEncode for SignedEndpoint {
530    fn raw_measure(&self, purpose: &Option<RawEncodePurpose>) -> Result<usize, BuckyError> {
531        self.0.raw_measure(purpose)
532    }
533
534    fn raw_encode<'a>(
535        &self,
536        buf: &'a mut [u8],
537        purpose: &Option<RawEncodePurpose>,
538    ) -> Result<&'a mut [u8], BuckyError> {
539        let bytes = self.raw_measure(purpose)?;
540        if buf.len() < bytes {
541            let msg = format!(
542                "not enough buffer for encode SignedEndpoint, except={}, got={}",
543                bytes,
544                buf.len()
545            );
546            error!("{}", msg);
547
548            return Err(BuckyError::new(BuckyErrorCode::OutOfLimit, msg));
549        }
550
551        buf[0] = self.0.flags() | ENDPOINT_FLAG_SIGNED;
552        self.0.raw_encode_no_flags(&mut buf[1..])
553    }
554}
555
556impl<'de> RawDecode<'de> for SignedEndpoint {
557    fn raw_decode(buf: &'de [u8]) -> Result<(Self, &'de [u8]), BuckyError> {
558        let min_bytes = Self::raw_min_bytes().unwrap();
559        if buf.len() < min_bytes {
560            let msg = format!(
561                "not enough buffer for decode SignedEndpoint, min bytes={}, got={}",
562                min_bytes,
563                buf.len()
564            );
565            error!("{}", msg);
566
567            return Err(BuckyError::new(BuckyErrorCode::OutOfLimit, msg));
568        }
569        let flags = buf[0];
570        if flags & ENDPOINT_FLAG_SIGNED == 0 {
571            return Err(BuckyError::new(
572                BuckyErrorCode::InvalidParam,
573                "without sign flag",
574            ));
575        }
576        let (ep, buf) = Endpoint::raw_decode_no_flags(flags, &buf[1..])?;
577        Ok((SignedEndpoint(ep), buf))
578    }
579}
580
#[cfg(test)]
mod test {
    use super::SignedEndpoint;
    use crate::*;
    // std types are sufficient here; no async runtime types are needed.
    use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6};

    #[test]
    fn test_codec() {
        let ep = Endpoint::default();
        let v = ep.to_vec().unwrap();
        let ep2 = Endpoint::clone_from_slice(&v).unwrap();
        assert_eq!(ep, ep2);

        let ep: Endpoint = (
            Protocol::Tcp,
            SocketAddr::from(SocketAddrV4::new(Ipv4Addr::new(127, 11, 22, 33), 4)),
        )
            .into();
        let v = ep.to_vec().unwrap();
        assert_eq!(v.len(), 1 + 2 + 4);
        let ep2 = Endpoint::clone_from_slice(&v).unwrap();
        assert_eq!(ep, ep2);

        // Flow info and scope id are not encoded, so use 0 for an exact round-trip.
        let ep6: Endpoint = (
            Protocol::Udp,
            SocketAddr::from(SocketAddrV6::new(
                Ipv6Addr::new(1, 2, 3, 4, 5, 6, 7, 8),
                9,
                0,
                0,
            )),
        )
            .into();
        let v = ep6.to_vec().unwrap();
        assert_eq!(v.len(), 1 + 2 + 16);
        assert_eq!(Endpoint::clone_from_slice(&v).unwrap(), ep6);
    }

    #[test]
    fn test_signed_codec() {
        let ep: Endpoint = (
            Protocol::Udp,
            SocketAddr::from(SocketAddrV4::new(Ipv4Addr::new(10, 0, 0, 1), 1234)),
        )
            .into();
        let signed = SignedEndpoint::from(ep);
        let v = signed.to_vec().unwrap();
        let decoded = SignedEndpoint::clone_from_slice(&v).unwrap();
        assert_eq!(decoded.as_ref(), &ep);

        // Plain endpoint bytes lack the signed flag and must be rejected.
        let plain = ep.to_vec().unwrap();
        assert!(SignedEndpoint::clone_from_slice(&plain).is_err());
    }

    #[test]
    fn endpoint() {
        assert_eq!(Endpoint::default().to_string(), "L4unk0.0.0.0:0");

        let ep: Endpoint = (
            Protocol::Tcp,
            SocketAddr::from(SocketAddrV4::new(Ipv4Addr::new(127, 1, 2, 3), 4)),
        )
            .into();
        assert_eq!(ep.to_string(), "L4tcp127.1.2.3:4");
        let parsed: Endpoint = "L4tcp127.1.2.3:4".parse().unwrap();
        assert_eq!(parsed, ep);

        let ep: Endpoint = (
            Protocol::Tcp,
            SocketAddr::from(SocketAddrV6::new(
                Ipv6Addr::new(1, 2, 3, 4, 5, 6, 7, 8),
                9,
                0,
                0,
            )),
        )
            .into();
        assert_eq!(ep.to_string(), "L6tcp[1:2:3:4:5:6:7:8]:9");
        let parsed: Endpoint = "L6tcp[1:2:3:4:5:6:7:8]:9".parse().unwrap();
        assert_eq!(parsed, ep);
    }
}