dns_codec/
lib.rs

1use num_bigint::BigUint;
2use thiserror::Error;
3
/// Output alphabet, indexed by digit value: `a`-`z` are digits 0-25, `0`-`9` are 26-35,
/// `-` is 36 and `.` is 37.
const FORWARD: &[u8] = b"abcdefghijklmnopqrstuvwxyz0123456789-.";

/// Inverse of `FORWARD` for every possible byte: maps an alphabet character (letters in
/// either case) back to its digit value. Bytes outside the alphabet map to 0, so this table
/// alone cannot tell them apart from `a`; callers must validate input separately.
const REVERSE: &[u8] = &build_reverse(FORWARD);

/// Builds the 256-entry inverse lookup table for `forward` at compile time, so it can never
/// drift out of sync with the alphabet.
///
/// Each character is registered both as-is and upper-cased, which makes decoding
/// case-insensitive (upper-casing is a no-op for digits, `-` and `.`).
const fn build_reverse(forward: &[u8]) -> [u8; 256] {
    let mut table = [0u8; 256];
    let mut digit = 0;
    while digit < forward.len() {
        let c = forward[digit];
        // `digit` is below `forward.len()` (38), so the narrowing cast cannot truncate.
        table[c as usize] = digit as u8;
        table[c.to_ascii_uppercase() as usize] = digit as u8;
        digit += 1;
    }
    table
}
21
/// Radix for the first character of a label: only the 26 letters at the start of `FORWARD`.
const RADIX_BEGIN: usize = 26;
/// Radix where the next character must not be `-` or `.`: letters plus the 10 digits.
const RADIX_END: usize = RADIX_BEGIN + 10;
/// Radix for every other position: the full alphabet, adding `-` and `.`.
const RADIX_ANY: usize = RADIX_END + 2;

/// Maximum number of characters the encoder places in one label before it inserts a `.`
/// (stays within the 63-octet DNS label limit).
const MAX_LABEL: usize = 61;
27
28#[derive(Error, Debug)]
29pub enum CodecError {
30    #[error("invalid data")]
31    InvalidData,
32}
33
34fn encode(data: &[u8]) -> Vec<u8> {
35    let mut value = BigUint::from_bytes_le(data);
36    value |= BigUint::from(1u8) << data.len() * 8;
37    let mut result = Vec::new();
38    let mut radix = RADIX_BEGIN;
39    let mut len = 0;
40    while value != BigUint::ZERO {
41        let rem: usize = (&value % radix).try_into().unwrap();
42        value /= radix;
43        let mut c = FORWARD[rem];
44        result.push(c);
45        len += 1;
46        if len == MAX_LABEL {
47            c = b'.';
48            result.push(c);
49        };
50        if c == b'.' {
51            radix = RADIX_BEGIN;
52            len = 0;
53        } else if c == b'-' || len == MAX_LABEL - 1 {
54            radix = RADIX_END;
55        } else {
56            radix = RADIX_ANY;
57        };
58    }
59    result.push(FORWARD[0]);
60    result
61}
62
63pub fn encode_string(data: &[u8]) -> String {
64    encode(data).into_iter().map(|x| char::from(x)).collect()
65}
66
67pub fn decode(data: &[u8]) -> Result<Vec<u8>, CodecError> {
68    let len = data.len();
69    if len < 2 {
70        return Err(CodecError::InvalidData);
71    }
72    if data[len - 1] != b'a' && data[len - 1] != b'A' {
73        return Err(CodecError::InvalidData);
74    }
75    let mut value = BigUint::ZERO;
76    let mut radix = RADIX_ANY;
77    for idx in (0..len - 1).rev() {
78        if data[idx] == b'.' && idx >= MAX_LABEL && !data[idx - MAX_LABEL..idx].contains(&b'.') {
79            radix = RADIX_END;
80            continue;
81        }
82        if idx > 0 && data[idx - 1] == b'-' {
83            radix = RADIX_END;
84        }
85        if idx == 0 || data[idx - 1] == b'.' {
86            radix = RADIX_BEGIN;
87        }
88        value *= radix;
89        value += REVERSE[data[idx] as usize];
90        radix = RADIX_ANY;
91    }
92    let mut result = value.to_bytes_le();
93    result.pop();
94    Ok(result)
95}
96
97pub fn decode_string(string: &str) -> Result<Vec<u8>, CodecError> {
98    decode(string.as_bytes())
99}
100
#[cfg(test)]
mod tests {
    use super::*;
    use rand::distributions::Standard;
    use rand::{thread_rng, Rng};

    /// Round-trips `value` through the codec, also with the encoding upper-cased.
    ///
    /// Also checks that the encoding is a well-formed name:
    ///
    /// - at most 248 characters;
    /// - every label starts with a letter and ends with a letter or digit;
    /// - only `[a-z0-9-]` inside labels, never `--`;
    /// - no label longer than `MAX_LABEL` characters.
    fn validate_encoding(value: &[u8]) {
        let encoded = encode_string(value);
        let decoded = decode_string(&encoded).unwrap();
        assert_eq!(value, decoded, "{value:?} => {encoded} => {decoded:?}");
        assert!(
            encoded.len() <= 248,
            "encoded length exceeds 248: {value:?} ({}) => {encoded} ({})",
            value.len(),
            encoded.len()
        );

        let upper = encoded.to_uppercase();
        let decoded = decode_string(&upper).unwrap();
        assert_eq!(value, decoded, "{value:?} => {upper} => {decoded:?}");

        let mut prev = '.';
        // Characters in the current label, not counting the separating '.'.
        let mut label_len = 0usize;
        for c in encoded.chars() {
            if prev == '.' {
                assert!(
                    c.is_ascii_alphabetic(),
                    "label start is not alphabetic: {encoded}"
                );
                label_len = 0;
            }
            if prev == '-' {
                assert_ne!(c, '-', "consecutive '-': {encoded}");
            }

            if c == '.' {
                assert!(
                    prev.is_ascii_alphanumeric(),
                    "label end is not alphanumeric: {encoded}"
                );
            } else {
                if c != '-' {
                    assert!(c.is_ascii_alphanumeric(), "invalid characters: {encoded}");
                }
                label_len += 1;
                assert!(
                    label_len <= MAX_LABEL,
                    "label exceeds max length of {MAX_LABEL}: {encoded}"
                );
            }

            prev = c;
        }
        assert_ne!(prev, '-', "encoding ends in '-': {encoded}");
        assert_ne!(prev, '.', "encoding ends in '.': {encoded}");
    }

    #[test]
    fn can_encode_and_decode_empty_vector() {
        validate_encoding(&[]);
    }

    #[test]
    fn can_encode_and_decode_all_1_byte_values() {
        for i in 0u8..=255u8 {
            validate_encoding(&[i]);
        }
    }

    #[test]
    fn can_encode_and_decode_all_2_byte_values() {
        for i in 0..=u16::MAX {
            validate_encoding(&i.to_be_bytes());
        }
    }

    #[test]
    fn can_encode_and_decode_the_first_million_3_byte_values() {
        for i in 0u32..1_000_000 {
            // Low three bytes of `i`, most significant first.
            validate_encoding(&i.to_be_bytes()[1..]);
        }
    }

    #[test]
    fn can_encode_and_decode_100k_random_values() {
        for _ in 0..100_000 {
            let len = thread_rng().gen_range(3..159);
            let value: Vec<u8> = thread_rng().sample_iter(Standard).take(len).collect();
            validate_encoding(&value);
        }
    }

    #[test]
    fn can_encode_and_decode_100k_max_len_random_values() {
        for _ in 0..100_000 {
            let value: Vec<u8> = thread_rng().sample_iter(Standard).take(158).collect();
            validate_encoding(&value);
        }
    }

    #[test]
    fn can_encode_and_decode_all_scaled_unit_vectors() {
        for len in 3..160 {
            for value in 0u8..=255u8 {
                let data = vec![value; len];
                validate_encoding(&data);
            }
        }
    }

    #[test]
    fn can_decode_and_reencode_max_len_short_labels() {
        for s in [
            "a.", "b.", "c.", "d.", "e.", "f.", "g.", "h.", "i.", "j.", "k.", "l.", "m.", "n.",
            "o.", "p.", "q.", "r.", "s.", "t.", "u.", "v.", "w.", "x.", "y.", "z.",
        ] {
            // 124 one-letter labels; the trailing '.' is replaced by the 'a' terminator.
            let mut encoded = String::from(s.repeat(124).strip_suffix('.').unwrap());
            encoded.push('a');
            let decoded = decode_string(&encoded).unwrap();
            assert!(
                decoded.len() >= 152,
                "{encoded} => {decoded:?} ({})",
                decoded.len()
            );
            let reencoded = encode_string(&decoded);
            // Only the all-'b' name is canonical; the others decode, but re-encode differently.
            if encoded.starts_with('b') {
                assert_eq!(encoded, reencoded);
                assert_eq!(decoded.len(), 153);
            } else {
                assert_ne!(encoded, reencoded);
            }
            let redecoded = decode_string(&reencoded).unwrap();
            assert_eq!(
                decoded, redecoded,
                "{encoded} => {decoded:?} => {reencoded} => {redecoded:?}"
            );
            assert!(
                reencoded.len() <= 248,
                "{encoded} => {decoded:?} => {reencoded} ({})",
                reencoded.len()
            );
        }
    }
}