amazon_cloudfront_client_routing_lib/
encode_decode.rs

1// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2// SPDX-License-Identifier: Apache-2.0
3
4use crate::{client_routing_label::EncodableData, errors::DecodeLengthError, bitwise::get_mask};
5
/// Lowercase form of the RFC 4648 Base32 alphabet. A character's index in
/// this slice is the 5-bit value it represents.
const BASE32_ALPHABET: &[u8] = b"abcdefghijklmnopqrstuvwxyz234567";
/// Number of bits carried by a single Base32 character.
const BASE32_NUM_BITS_IN_CHAR: u8 = 5;
/// Initial capacity reserved for an encoded label.
// NOTE(review): presumably the 63-octet DNS label limit - confirm.
const MAX_DNS_LABEL_SIZE: u8 = 63;
9
/// Stateless Base32 codec for [`EncodableData`].
///
/// Encoding, decoding, and label validation all use the lowercase form of
/// the RFC 4648 Base32 alphabet, where every group of 5 bits in
/// [`EncodableData`] maps to exactly one character. For efficiency, a
/// character outside the alphabet does not invalidate the whole label; it
/// is simply read as `'a'`.
///
/// The struct has no fields. See the individual methods, or
/// [`ClientRoutingLabel`](crate::client_routing_label::ClientRoutingLabel),
/// for usage.
#[derive(Clone, Copy, Debug)]
pub struct Base32 {}
19
20impl Base32 {
21    /// Returns a lowercase Base32 string encoded from `encodable_data`.
22    /// 
23    /// Iterates over `encodable_data`, encoding bits from `value` until 
24    /// not enough bits remain to make a full char. Remaining bits are
25    /// then used in the subsequent iteration. After iterating over
26    /// everything, if there are not enough bits to make a char 0 will
27    /// be used to pad the left over bits. Encoding uses a lowercase
28    /// version of the RFC 4648 Base32 alphabet.
29    /// 
30    /// # Examples:
31    /// ```
32    /// use amazon_cloudfront_client_routing_lib::encode_decode::Base32;
33    /// use amazon_cloudfront_client_routing_lib::client_routing_label::EncodableData;
34    /// 
35    /// let encoding_system = Base32 {};
36    /// let encodable_data = &mut [
37    ///     EncodableData { // 0b01010 => "k"
38    ///         value: 10,
39    ///         num_bits: 5
40    ///     },
41    ///     EncodableData { // 0b00011_11011 => "d3"
42    ///         value: 123,
43    ///         num_bits: 10
44    ///     },
45    ///     EncodableData { // 0b0 => "a"
46    ///         value: 0,
47    ///         num_bits: 1
48    ///     },
49    /// ];
50    /// 
51    /// assert_eq!("kd3a", encoding_system.encode(encodable_data));
52    /// ```
53    pub fn encode(&self, encodable_data: &mut [EncodableData]) -> String {
54        let value_mask: u64 = get_mask(BASE32_NUM_BITS_IN_CHAR);
55        let mut encoded_data: Vec<char> = Vec::with_capacity(MAX_DNS_LABEL_SIZE as usize);
56        let mut value_to_encode: u8 = 0;
57        let mut num_bits_left_over: u8 = 0;
58        for data in encodable_data.iter_mut() {
59            while data.has_bits_for_char(BASE32_NUM_BITS_IN_CHAR - num_bits_left_over) {
60                value_to_encode += data.get_next_bits_to_encode(BASE32_NUM_BITS_IN_CHAR - num_bits_left_over);
61                encoded_data.push(BASE32_ALPHABET[value_to_encode as usize] as char);
62
63                num_bits_left_over = 0;
64                value_to_encode = 0;
65            }
66
67            value_to_encode |= ((data.value << (BASE32_NUM_BITS_IN_CHAR - (data.num_bits + num_bits_left_over))) & value_mask) as u8;
68            num_bits_left_over += data.num_bits;
69        }
70
71        if num_bits_left_over > 0 {
72            encoded_data.push(BASE32_ALPHABET[value_to_encode as usize] as char);
73        }
74
75        encoded_data.iter().collect()
76    }
77
78    /// Validates `client_routing_label` is the proper length to fit `total_num_bits`.
79    /// 
80    /// Calculates how many chars would be encoded for `total_num_bits` and then
81    /// checks if the `client_routing_label` has that many chars. Returns a [`Result`]
82    /// with '()' if it's valid or a [`DecodeLengthError`] if it's not valid.
83    /// 
84    /// # Examples:
85    /// ```
86    /// use amazon_cloudfront_client_routing_lib::encode_decode::Base32;
87    /// 
88    /// let encoding_system = Base32 {};
89    /// 
90    /// // valid
91    /// match encoding_system.is_valid_client_routing_label(145, b"abaaaaaaaaaaaaaaaaaaaackvj5oa") {
92    ///     Ok(()) => (),
93    ///     Err(_e) => panic!("Threw error when shouldn't have.")
94    /// };
95    /// 
96    /// // invalid
97    /// match encoding_system.is_valid_client_routing_label(145, b"abaaaaaaaaaaaaaaaaaaaackvj5oabcd") {
98    ///     Ok(()) => (),
99    ///     Err(e) => assert_eq!("Passed 32 - expected 29 characters", e.to_string())
100    /// };
101    /// ```
102    pub fn is_valid_client_routing_label(
103        &self,
104        total_num_bits: u8,
105        client_routing_label: &[u8],
106    ) -> Result<(), DecodeLengthError> {
107        if client_routing_label.len() as u8
108            != (total_num_bits + BASE32_NUM_BITS_IN_CHAR - 1) / BASE32_NUM_BITS_IN_CHAR
109        {
110            let e = DecodeLengthError {
111                num_chars: client_routing_label.len(),
112                expected_num_chars: ((total_num_bits + BASE32_NUM_BITS_IN_CHAR - 1)
113                    / BASE32_NUM_BITS_IN_CHAR) as usize,
114            };
115            return Err(e);
116        }
117
118        Ok(())
119    }
120
121    /// Sets `encodable_data` based on passed `encoded_label`.
122    /// 
123    /// Validates `encoded_label` is valid based on `total_num_bits`. If not valid,
124    /// returns a [`Result`] containing [`DecodeLengthError`]. If valid, iterates
125    /// over `encodable_data` and sets each value based on the label value. Invalid
126    /// characters in a label are treated as if they had a value of 0.
127    /// 
128    /// # Examples:
129    /// ```
130    /// use amazon_cloudfront_client_routing_lib::encode_decode::Base32;
131    /// use amazon_cloudfront_client_routing_lib::client_routing_label::EncodableData;
132    /// 
133    /// let encoding_system = Base32 {};
134    /// 
135    /// // valid
136    /// let encodable_data = &mut [
137    ///     EncodableData {
138    ///         value: 0,
139    ///         num_bits: 5
140    ///     },
141    ///     EncodableData {
142    ///         value: 0,
143    ///         num_bits: 10
144    ///     },
145    ///     EncodableData {
146    ///         value: 0,
147    ///         num_bits: 1
148    ///     },
149    /// ];
150    /// 
151    /// match encoding_system.decode(encodable_data, b"kd3a", 16) {
152    ///     Ok(()) => {
153    ///         assert_eq!(10, encodable_data[0].value);
154    ///         assert_eq!(123, encodable_data[1].value);
155    ///         assert_eq!(0, encodable_data[2].value);
156    ///     },
157    ///     Err(_e) => panic!("Threw error when shouldn't have.")
158    /// };
159    /// 
160    /// // invalid
161    /// match encoding_system.decode(encodable_data, b"kd3a", 10) {
162    ///     Ok(()) => panic!("Didn't throw error when should have."),
163    ///     Err(e) => assert_eq!("Passed 4 - expected 2 characters", e.to_string())
164    /// };
165    /// ```
166    pub fn decode(
167        &self,
168        encodable_data: &mut [EncodableData],
169        encoded_label: &[u8],
170        total_num_bits: u8,
171    ) -> Result<(), DecodeLengthError> {
172        match self.is_valid_client_routing_label(total_num_bits, encoded_label) {
173            Ok(()) => (),
174            Err(e) => return Err(e),
175        };
176
177        let mut label_values: Vec<u8> = encoded_label
178            .iter()
179            .map(|a| BASE32_ALPHABET.iter().position(|b| a == b).unwrap_or(0) as u8)
180            .collect();
181
182        let mut num_bits_in_char: u8 = BASE32_NUM_BITS_IN_CHAR;
183        let mut label_index: usize = 0;
184        for data in encodable_data.iter_mut() {
185            let original_num_bits: u8 = data.num_bits;
186            data.value = 0;
187            
188            while data.has_bits_for_char(num_bits_in_char) {
189                data.add_bits(num_bits_in_char, label_values[label_index]);
190                label_index += 1;
191                num_bits_in_char = BASE32_NUM_BITS_IN_CHAR;
192            }
193            
194            if data.num_bits > 0 {
195                num_bits_in_char -= data.num_bits;
196                data.add_bits(data.num_bits, label_values[label_index] >> num_bits_in_char);
197                label_values[label_index] &= get_mask(num_bits_in_char) as u8;
198            }
199
200            data.num_bits = original_num_bits;
201        }
202
203        Ok(())
204    }
205}
206
#[cfg(test)]
mod tests {
    use super::*;
    use crate::client_routing_label::EncodableData;

    /// Builds the fields to encode/decode from `(value, num_bits)` pairs.
    fn fields(specs: &[(u64, u8)]) -> Vec<EncodableData> {
        specs
            .iter()
            .map(|&(value, num_bits)| EncodableData { value, num_bits })
            .collect()
    }

    /// Encodes the fields described by `specs` and returns the resulting label.
    fn encode_fields(specs: &[(u64, u8)]) -> String {
        let encoding_system = Base32 {};
        let mut data = fields(specs);
        encoding_system.encode(&mut data)
    }

    /// Decodes `label` into the fields described by `specs` and returns the
    /// decoded values in field order. Panics if decoding fails.
    fn decode_values(specs: &[(u64, u8)], label: &[u8], total_num_bits: u8) -> Vec<u64> {
        let encoding_system = Base32 {};
        let mut data = fields(specs);
        if let Err(e) = encoding_system.decode(&mut data, label, total_num_bits) {
            panic!("Threw error when shouldn't have: {}", e);
        }
        data.iter().map(|field| field.value).collect()
    }

    /// Decodes `label` expecting a failure and returns the error message.
    fn decode_error(specs: &[(u64, u8)], label: &[u8], total_num_bits: u8) -> String {
        let encoding_system = Base32 {};
        let mut data = fields(specs);
        match encoding_system.decode(&mut data, label, total_num_bits) {
            Ok(()) => panic!("Didn't throw error when should have"),
            Err(e) => e.to_string(),
        }
    }

    // Every value fits in its num_bits.
    // 205 bits in total: a multiple of 5, so the last char needs no padding.
    #[test]
    fn validate_encode_value_proper_size_no_padding_needed() {
        let label = encode_fields(&[
            (0, 109),
            (1, 5),
            (0, 1),
            (6148494311290830848, 64),
            (24, 6),
            (957415, 20),
        ]);

        assert_eq!("aaaaaaaaaaaaaaaaaaaaaackvj5oaaaaaaaay5g7h", label);
    }

    // Every value fits in its num_bits.
    // 48 bits in total: not a multiple of 5, so the last char is padded.
    #[test]
    fn validate_encode_value_proper_size_padding_needed() {
        let label = encode_fields(&[(36, 12), (3734643, 22), (2367, 14)]);

        assert_eq!("ajhd6hgjh4", label);
    }

    // The first value (5346) does not fit in its 5 bits.
    // NOTE: despite the test name, the 166 bits in total are not a multiple
    // of 5, so the last char is padded.
    #[test]
    fn validate_encode_value_too_large_no_padding_needed() {
        let label = encode_fields(&[(5346, 5), (3474, 56), (0, 14), (0, 8), (46374, 83)]);

        assert_eq!("caaaaaaaabwjaaaaaaaaaaaaaaaaaawuta", label);
    }

    // Several values (2423, 432, 64) do not fit in their num_bits.
    // 36 bits in total: not a multiple of 5, so the last char is padded.
    #[test]
    fn validate_encode_value_too_large_padding_needed() {
        let label = encode_fields(&[(2423, 5), (432, 3), (31, 12), (43, 10), (64, 6)]);

        assert_eq!("xaa7blaa", label);
    }

    // No fields at all encode to an empty label.
    #[test]
    fn validate_encode_empty_data() {
        assert_eq!("", encode_fields(&[]));
    }

    // 3 bits in total: fewer than one full char, so a single padded char is produced.
    #[test]
    fn validate_encode_not_enough_data_for_char() {
        assert_eq!("y", encode_fields(&[(1, 1), (2, 2)]));
    }

    // 205 bits: the label has no padding bits in its last char.
    #[test]
    fn validate_decode_label_with_no_padding() {
        let decoded = decode_values(
            &[(0, 109), (0, 5), (0, 1), (0, 64), (0, 6), (0, 20)],
            b"aaaaaaaaaaaaaaaaaaaaaackvj5oaaaaaaaay5g7h",
            205,
        );

        assert_eq!(vec![0, 1, 0, 6148494311290830848, 24, 957415], decoded);
    }

    // 48 bits: the last char of the label carries padding bits.
    #[test]
    fn validate_decode_label_with_padding() {
        let decoded = decode_values(&[(0, 12), (0, 22), (0, 14)], b"ajhd6hgjh4", 48);

        assert_eq!(vec![36, 3734643, 2367], decoded);
    }

    // Values already present in the fields are replaced by the decoded ones.
    #[test]
    fn validate_decode_data_already_has_value() {
        let decoded = decode_values(
            &[(2423, 5), (53, 3), (43, 12), (754, 10), (34, 6)],
            b"xaa7blaa",
            36,
        );

        assert_eq!(vec![23, 0, 31, 43, 0], decoded);
    }

    // An empty label is valid for 0 bits and no fields.
    #[test]
    fn validate_decode_empty_label() {
        let decoded = decode_values(&[], b"", 0);

        assert!(decoded.is_empty());
    }

    // 46 bits need 10 chars; a 12 char label is rejected.
    #[test]
    fn validate_decode_label_too_large() {
        let message = decode_error(&[(0, 12), (0, 22), (0, 14)], b"abacabacdfed", 46);

        assert_eq!("Passed 12 - expected 10 characters", message);
    }

    // 46 bits need 10 chars; a 3 char label is rejected.
    #[test]
    fn validate_decode_label_too_small() {
        let message = decode_error(&[(0, 12), (0, 22), (0, 14)], b"aba", 46);

        assert_eq!("Passed 3 - expected 10 characters", message);
    }
}