adns_proto/
name.rs

1use core::fmt;
2use std::{
3    hash::{Hash, Hasher},
4    str::FromStr,
5};
6
7use smallvec::SmallVec;
8use thiserror::Error;
9
10#[derive(Clone, Debug, Default, Eq)]
11pub struct Name {
12    full: String,
13    segment_indices: SmallVec<[u16; 8]>,
14}
15
#[cfg(feature = "serde")]
impl serde::Serialize for Name {
    /// Serializes the name as its normalized dotted string.
    fn serialize<S: serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
        serializer.serialize_str(self.full.as_str())
    }
}
22
#[cfg(feature = "serde")]
impl<'de> serde::Deserialize<'de> for Name {
    /// Deserializes an owned string and parses it with `FromStr`, reporting any
    /// `NameParseError` as a custom deserialization error.
    fn deserialize<D: serde::Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
        let text = <String as serde::Deserialize<'de>>::deserialize(deserializer)?;
        text.parse::<Self>()
            .map_err(<D::Error as serde::de::Error>::custom)
    }
}
30
31impl PartialEq for Name {
32    fn eq(&self, other: &Self) -> bool {
33        self.full == other.full
34    }
35}
36
37impl PartialOrd for Name {
38    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
39        self.full.partial_cmp(&other.full)
40    }
41}
42
43impl Ord for Name {
44    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
45        self.full.cmp(&other.full)
46    }
47}
48
49impl AsRef<str> for Name {
50    fn as_ref(&self) -> &str {
51        &self.full
52    }
53}
54
55impl fmt::Display for Name {
56    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
57        write!(f, "{}", self.full)
58    }
59}
60
61impl Hash for Name {
62    fn hash<H: Hasher>(&self, state: &mut H) {
63        self.full.hash(state);
64    }
65}
66
67#[derive(Error, Debug)]
68pub enum NameParseError {
69    #[error("name label segment over 63 char long")]
70    NameLabelTooLong,
71    #[error("name over 255 char long")]
72    NameTooLong,
73}
74
75impl FromStr for Name {
76    type Err = NameParseError;
77
78    fn from_str(s: &str) -> Result<Self, Self::Err> {
79        if s.len() > 255 {
80            return Err(NameParseError::NameTooLong);
81        }
82        let mut out = Name {
83            full: String::with_capacity(s.len() + 1),
84            segment_indices: Default::default(),
85        };
86        for x in s.split('.') {
87            out.push_segment(x)?;
88        }
89        Ok(out)
90    }
91}
92
93impl Name {
94    pub fn ends_with(&self, other: &Name) -> bool {
95        if self.full == other.full {
96            return true;
97        }
98        if self.segment_indices.len() < other.segment_indices.len() {
99            return false;
100        }
101        for (self_segment, other_segment) in self.segments().rev().zip(other.segments().rev()) {
102            if self_segment != other_segment {
103                return false;
104            }
105        }
106        true
107    }
108
109    /// matches ** -> any number of segments (prefix only), *+ -> matches one or more segments, * -> any one segment, @ -> empty
110    pub fn contains(&self, other: &Name) -> bool {
111        if self.full == other.full {
112            return true;
113        }
114
115        let mut segments = self.segments().peekable();
116
117        // wildcard prefix
118        if segments.peek().copied() == Some("**") {
119            segments.next().unwrap();
120            if other.segment_indices.len() < self.segment_indices.len().saturating_sub(1) {
121                return false;
122            }
123            for (other, ours) in other.segments().rev().zip(segments.rev()) {
124                if ours != "*" && other != ours {
125                    return false;
126                }
127            }
128        } else if segments.peek().copied() == Some("*+") {
129            segments.next().unwrap();
130            if other.segment_indices.len() < self.segment_indices.len() {
131                return false;
132            }
133            for (other, ours) in other.segments().rev().zip(segments.rev()) {
134                if ours != "*" && other != ours {
135                    return false;
136                }
137            }
138        } else {
139            if other.segment_indices.len() != self.segment_indices.len() {
140                return false;
141            }
142            for (other, ours) in other.segments().zip(segments) {
143                if ours != "*" && other != ours {
144                    return false;
145                }
146            }
147        }
148
149        true
150    }
151
152    pub fn from_segments<S: AsRef<str>>(
153        segments: impl IntoIterator<Item = S>,
154    ) -> Result<Self, NameParseError> {
155        let mut out = Self::default();
156        for segment in segments {
157            out.push_segment(segment.as_ref())?;
158        }
159        if out.full.len() > 255 {
160            return Err(NameParseError::NameTooLong);
161        }
162        Ok(out)
163    }
164
165    pub fn push_segment(&mut self, segment: impl AsRef<str>) -> Result<(), NameParseError> {
166        let segment = segment.as_ref();
167        if segment.is_empty() {
168            return Ok(());
169        }
170        if segment.len() > 63 {
171            return Err(NameParseError::NameLabelTooLong);
172        }
173        self.full.reserve(segment.len() + 1);
174        if !self.full.is_empty() {
175            self.full.push('.');
176        }
177        let start = self.full.len();
178        self.full.push_str(segment);
179        self.full[start..].make_ascii_lowercase();
180        if self.full.len() > 255 {
181            return Err(NameParseError::NameTooLong);
182        }
183        self.segment_indices.push(start.try_into().unwrap());
184        Ok(())
185    }
186
187    pub fn segments(&self) -> SegmentIterator<'_> {
188        SegmentIterator {
189            name: self,
190            index: 0,
191            end_index: self.segment_indices.len(),
192        }
193    }
194}
195
196pub struct SegmentIterator<'a> {
197    name: &'a Name,
198    index: usize,
199    end_index: usize,
200}
201
202impl<'a> Iterator for SegmentIterator<'a> {
203    type Item = &'a str;
204
205    fn next(&mut self) -> Option<Self::Item> {
206        if self.index >= self.end_index {
207            return None;
208        }
209        let index = *self.name.segment_indices.get(self.index)?;
210        let end = self
211            .name
212            .segment_indices
213            .get(self.index + 1)
214            .map(|x| x.saturating_sub(1))
215            .unwrap_or(self.name.full.len() as u16);
216        self.index += 1;
217        self.name.full.get(index as usize..end as usize)
218    }
219}
220
221impl<'a> DoubleEndedIterator for SegmentIterator<'a> {
222    fn next_back(&mut self) -> Option<Self::Item> {
223        if self.index >= self.end_index {
224            return None;
225        }
226        let end = if self.end_index >= self.name.segment_indices.len() {
227            self.name.full.len() as u16
228        } else {
229            self.name
230                .segment_indices
231                .get(self.end_index)
232                .unwrap()
233                .saturating_sub(1)
234        };
235        let index = self
236            .end_index
237            .checked_sub(1)
238            .map(|x| self.name.segment_indices.get(x).unwrap())
239            .copied()
240            .unwrap_or_default();
241        self.end_index = self.end_index.checked_sub(1).unwrap();
242        self.name.full.get(index as usize..end as usize)
243    }
244}
245
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `s`, panicking on failure; keeps the assertions below short.
    fn name(s: &str) -> Name {
        s.parse().unwrap()
    }

    #[test]
    fn test_name() {
        let name: Name = "test.com".parse().unwrap();
        {
            let mut iter = name.segments();
            assert_eq!(iter.next(), Some("test"));
            assert_eq!(iter.next(), Some("com"));
            assert_eq!(iter.next(), None);
        }
        {
            let mut iter = name.segments().rev();
            assert_eq!(iter.next(), Some("com"));
            assert_eq!(iter.next(), Some("test"));
            assert_eq!(iter.next(), None);
        }

        let name_container: Name = "**.test.com".parse().unwrap();
        assert!(name_container.contains(&name));
        let name_container: Name = "*+.test.com".parse().unwrap();
        assert!(!name_container.contains(&name));
        let name2: Name = "west.test.com".parse().unwrap();
        assert!(name_container.contains(&name2));
        let name_container: Name = "west.*.com".parse().unwrap();
        assert!(name_container.contains(&name2));
        assert!(!name_container.contains(&name));

        assert!(name2.ends_with(&name));
        assert!(!name.ends_with(&name2));
    }

    #[test]
    fn parsing_normalizes_case_and_empty_labels() {
        assert_eq!(name("TeSt.CoM"), name("test.com"));
        assert_eq!(name("test.com.").to_string(), "test.com");

        let doubled = name("test..com");
        let labels: Vec<&str> = doubled.segments().collect();
        assert_eq!(labels, ["test", "com"]);

        assert_eq!(
            Name::from_segments(["west", "Test", "com"]).unwrap(),
            name("west.test.com")
        );
        assert_eq!(name(""), Name::default());
        assert_eq!(name("").segments().next(), None);
        assert!(name("test.com").ends_with(&name("")));
        assert!(!name("atest.com").ends_with(&name("test.com")));
    }

    #[test]
    fn parsing_enforces_length_limits() {
        assert!("a".repeat(63).parse::<Name>().is_ok());
        assert!(matches!(
            "a".repeat(64).parse::<Name>(),
            Err(NameParseError::NameLabelTooLong)
        ));
        assert!(matches!(
            "a.".repeat(128).parse::<Name>(),
            Err(NameParseError::NameTooLong)
        ));

        // Four 63-byte labels joined with dots are exactly 255 bytes.
        assert!(Name::from_segments(std::iter::repeat("a".repeat(63)).take(4)).is_ok());
        assert!(matches!(
            Name::from_segments(std::iter::repeat("a".repeat(63)).take(5)),
            Err(NameParseError::NameTooLong)
        ));
    }

    #[test]
    fn segments_iterate_from_both_ends() {
        let abc = name("a.b.c");
        let mut iter = abc.segments();
        assert_eq!(iter.next(), Some("a"));
        assert_eq!(iter.next_back(), Some("c"));
        assert_eq!(iter.next(), Some("b"));
        assert_eq!(iter.next_back(), None);
        assert_eq!(iter.next(), None);
    }

    #[test]
    fn wildcard_patterns() {
        let target = name("x.y.com");
        assert!(name("**.*.com").contains(&target));
        assert!(!name("**.*.com").contains(&name("com")));
        assert!(name("*+.com").contains(&target));
        assert!(!name("*+.com").contains(&name("com")));
        assert!(name("*.y.com").contains(&target));
        assert!(!name("*.com").contains(&target));
        assert!(name("**").contains(&name("")));
    }
}