dns_message_parser/rr/subtypes.rs

1use std::cmp::Ordering;
2use std::convert::TryFrom;
3use std::fmt::{Display, Formatter, Result as FmtResult};
4use std::net::{Ipv4Addr, Ipv6Addr};
5use std::slice::Iter;
6use thiserror::Error;
7
8const MASK: u8 = 0b1111_1111;
9
10fn check_ipv4_addr(ipv4_addr: &Ipv4Addr, prefix_length: u8) -> Result<(), AddressError> {
11    match 32.cmp(&prefix_length) {
12        Ordering::Less => Err(AddressError::Ipv4Prefix(prefix_length)),
13        Ordering::Equal => Ok(()),
14        Ordering::Greater => {
15            let octects = ipv4_addr.octets();
16            let index = (prefix_length / 8) as usize;
17            let remain = prefix_length % 8;
18
19            if (octects[index] & (MASK >> remain)) != 0 {
20                return Err(AddressError::Ipv4Mask(*ipv4_addr, prefix_length));
21            }
22
23            let (_, octects_right) = octects.split_at(index + 1);
24            for b in octects_right {
25                if *b != 0 {
26                    return Err(AddressError::Ipv4Mask(*ipv4_addr, prefix_length));
27                }
28            }
29
30            Ok(())
31        }
32    }
33}
34
35fn check_ipv6_addr(ipv6_addr: &Ipv6Addr, prefix_length: u8) -> Result<(), AddressError> {
36    match 128.cmp(&prefix_length) {
37        Ordering::Less => Err(AddressError::Ipv6Prefix(prefix_length)),
38        Ordering::Equal => Ok(()),
39        Ordering::Greater => {
40            let octects = ipv6_addr.octets();
41            let index = (prefix_length / 8) as usize;
42            let remain = prefix_length % 8;
43
44            if (octects[index] & (MASK >> remain)) != 0 {
45                return Err(AddressError::Ipv6Mask(*ipv6_addr, prefix_length));
46            }
47
48            let (_, octects_right) = octects.split_at(index + 1);
49            for b in octects_right {
50                if *b != 0 {
51                    return Err(AddressError::Ipv6Mask(*ipv6_addr, prefix_length));
52                }
53            }
54
55            Ok(())
56        }
57    }
58}
59
/// An IP address of either family.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum Address {
    /// An IPv4 address.
    Ipv4(Ipv4Addr),
    /// An IPv6 address.
    Ipv6(Ipv6Addr),
}
65
66#[derive(Debug, Clone, PartialEq, Eq, Hash, Error)]
67pub enum AddressError {
68    #[error("Prefix length is not between 0 and 32: {0}")]
69    Ipv4Prefix(u8),
70    #[error("IPv4 {0} does not fit {1} mask")]
71    Ipv4Mask(Ipv4Addr, u8),
72    #[error("Prefix length is not between 0 and 128: {0}")]
73    Ipv6Prefix(u8),
74    #[error("IPv6 {0} does not fit {1} mask")]
75    Ipv6Mask(Ipv6Addr, u8),
76}
77
78impl Address {
79    pub fn check_prefix(&self, prefix_length: u8) -> Result<(), AddressError> {
80        match self {
81            Address::Ipv4(ipv4_addr) => check_ipv4_addr(ipv4_addr, prefix_length),
82            Address::Ipv6(ipv6_addr) => check_ipv6_addr(ipv6_addr, prefix_length),
83        }
84    }
85}
86
87impl Display for Address {
88    fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult {
89        match self {
90            Address::Ipv4(ipv4_addr) => ipv4_addr.fmt(f),
91            Address::Ipv6(ipv6_addr) => ipv6_addr.fmt(f),
92        }
93    }
94}
95
96impl Address {
97    pub const fn get_address_family_number(&self) -> AddressFamilyNumber {
98        match self {
99            Address::Ipv4(_) => AddressFamilyNumber::Ipv4,
100            Address::Ipv6(_) => AddressFamilyNumber::Ipv6,
101        }
102    }
103}
104
// Numeric address family identifiers, as returned by `Address::get_address_family_number`.
// Stored as `u16` (`#[repr(u16)]`): 1 = IPv4, 2 = IPv6.
// NOTE(review): the macro is defined elsewhere in the crate; judging by its name it generates
// integer <-> enum conversions but no `Display` impl (this type's `Display` is written by
// hand in this file) — confirm against the macro definition.
try_from_enum_to_integer_without_display! {
    #[repr(u16)]
    #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
    pub enum AddressFamilyNumber {
        Ipv4 = 0x0001,
        Ipv6 = 0x0002,
    }
}
113
114impl Display for AddressFamilyNumber {
115    fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult {
116        match self {
117            AddressFamilyNumber::Ipv4 => write!(f, "IPv4"),
118            AddressFamilyNumber::Ipv6 => write!(f, "IPv6"),
119        }
120    }
121}
122
/// A [`Vec`] that holds at least one element.
///
/// The inner vector is private to this module; from outside, a value is obtained through
/// `TryFrom<Vec<T>>`, which rejects empty vectors.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub struct NonEmptyVec<T>(Vec<T>);

impl<T> NonEmptyVec<T> {
    /// Returns an iterator over references to the elements, in insertion order.
    pub fn iter(&self) -> Iter<'_, T> {
        let NonEmptyVec(items) = self;
        items.iter()
    }
}

impl<T> TryFrom<Vec<T>> for NonEmptyVec<T> {
    /// Signals that the input vector was empty; no further detail is carried.
    type Error = ();

    fn try_from(vec: Vec<T>) -> Result<Self, Self::Error> {
        if vec.is_empty() {
            return Err(());
        }
        Ok(Self(vec))
    }
}

impl<T> From<NonEmptyVec<T>> for Vec<T> {
    /// Unwraps the inner vector, dropping the non-empty guarantee.
    fn from(vec: NonEmptyVec<T>) -> Vec<T> {
        let NonEmptyVec(inner) = vec;
        inner
    }
}
146        vec.0
147    }
148}