Skip to main content

hickory_proto/dnssec/
supported_algorithm.rs

1/*
2 * Copyright (C) 2015 Benjamin Fry <benjaminfry@me.com>
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     https://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
//! Bitmap for expressing the set of supported algorithms in EDNS.
18
19use alloc::vec::Vec;
20use core::fmt::{self, Display, Formatter};
21
22#[cfg(feature = "serde")]
23use serde::{Deserialize, Serialize};
24
25use tracing::warn;
26
27use super::Algorithm;
28use crate::error::ProtoResult;
29use crate::serialize::binary::{BinEncodable, BinEncoder};
30
/// The set of algorithms supported by a client or server, exchanged between the two.
///
/// Stored as a compact bitmap with one bit per representable algorithm; the bit assignments
/// are defined by `SupportedAlgorithms::from_pos`.
#[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
#[cfg_attr(feature = "serde", derive(Deserialize, Serialize))]
pub struct SupportedAlgorithms {
    // One bit per algorithm. Fewer than 8 algorithms are mapped today, so a u8 is enough.
    bit_map: u8,
}
38
39impl SupportedAlgorithms {
40    /// Return a new set of Supported algorithms
41    pub fn new() -> Self {
42        Self { bit_map: 0 }
43    }
44
45    /// Specify the entire set is supported
46    pub fn all() -> Self {
47        Self {
48            bit_map: 0b0111_1111,
49        }
50    }
51
52    /// Based on the set of Algorithms, return the supported set
53    pub fn from_vec(algorithms: &[Algorithm]) -> Self {
54        let mut supported = Self::new();
55
56        for a in algorithms {
57            supported.set(*a);
58        }
59
60        supported
61    }
62
63    fn pos(algorithm: Algorithm) -> Option<u8> {
64        // not using the values from the RFC's to keep the bit_map space condensed
65        #[allow(deprecated)]
66        let bit_pos: Option<u8> = match algorithm {
67            Algorithm::RSASHA1 => Some(0),
68            Algorithm::RSASHA256 => Some(1),
69            Algorithm::RSASHA1NSEC3SHA1 => Some(2),
70            Algorithm::RSASHA512 => Some(3),
71            Algorithm::ECDSAP256SHA256 => Some(4),
72            Algorithm::ECDSAP384SHA384 => Some(5),
73            Algorithm::ED25519 => Some(6),
74            Algorithm::RSAMD5 | Algorithm::DSA | Algorithm::Unknown(_) => None,
75        };
76
77        bit_pos.map(|b| 1u8 << b)
78    }
79
80    fn from_pos(pos: u8) -> Option<Algorithm> {
81        // TODO: should build a code generator or possibly a macro for deriving these inversions
82        #[allow(deprecated)]
83        match pos {
84            0 => Some(Algorithm::RSASHA1),
85            1 => Some(Algorithm::RSASHA256),
86            2 => Some(Algorithm::RSASHA1NSEC3SHA1),
87            3 => Some(Algorithm::RSASHA512),
88            4 => Some(Algorithm::ECDSAP256SHA256),
89            5 => Some(Algorithm::ECDSAP384SHA384),
90            6 => Some(Algorithm::ED25519),
91            _ => None,
92        }
93    }
94
95    /// Set the specified algorithm as supported
96    pub fn set(&mut self, algorithm: Algorithm) {
97        if let Some(bit_pos) = Self::pos(algorithm) {
98            self.bit_map |= bit_pos;
99        }
100    }
101
102    /// Returns true if the algorithm is supported
103    pub fn has(self, algorithm: Algorithm) -> bool {
104        if let Some(bit_pos) = Self::pos(algorithm) {
105            (bit_pos & self.bit_map) == bit_pos
106        } else {
107            false
108        }
109    }
110
111    /// Return an Iterator over the supported set.
112    pub fn iter(&self) -> impl Iterator<Item = Algorithm> + '_ {
113        SupportedAlgorithmsIter {
114            algorithms: self,
115            current: 0,
116        }
117    }
118
119    /// Return the count of supported algorithms
120    pub fn len(self) -> u16 {
121        // this is pretty much guaranteed to be less that u16::MAX
122        self.iter().count() as u16
123    }
124
125    /// Return true if no SupportedAlgorithms are set, this implies the option is not supported
126    pub fn is_empty(self) -> bool {
127        self.bit_map == 0
128    }
129}
130
131impl Default for SupportedAlgorithms {
132    fn default() -> Self {
133        Self::new()
134    }
135}
136
137impl Display for SupportedAlgorithms {
138    fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), fmt::Error> {
139        for a in self.iter() {
140            a.fmt(f)?;
141            f.write_str(", ")?;
142        }
143
144        Ok(())
145    }
146}
147
148impl<'a> From<&'a [u8]> for SupportedAlgorithms {
149    fn from(values: &'a [u8]) -> Self {
150        let mut supported = Self::new();
151
152        for a in values.iter().map(|i| Algorithm::from_u8(*i)) {
153            match a {
154                Algorithm::Unknown(v) => warn!("unrecognized algorithm: {}", v),
155                a => supported.set(a),
156            }
157        }
158
159        supported
160    }
161}
162
163impl<'a> From<&'a SupportedAlgorithms> for Vec<u8> {
164    fn from(value: &'a SupportedAlgorithms) -> Self {
165        let mut bytes = Self::with_capacity(8); // today this is less than 8
166
167        for a in value.iter() {
168            bytes.push(a.into());
169        }
170
171        bytes.shrink_to_fit();
172        bytes
173    }
174}
175
176impl From<Algorithm> for SupportedAlgorithms {
177    fn from(algorithm: Algorithm) -> Self {
178        Self::from_vec(&[algorithm])
179    }
180}
181
182struct SupportedAlgorithmsIter<'a> {
183    algorithms: &'a SupportedAlgorithms,
184    current: usize,
185}
186
187impl Iterator for SupportedAlgorithmsIter<'_> {
188    type Item = Algorithm;
189    fn next(&mut self) -> Option<Self::Item> {
190        // some quick bounds checking
191        if self.current > u8::MAX as usize {
192            return None;
193        }
194
195        while let Some(algorithm) = SupportedAlgorithms::from_pos(self.current as u8) {
196            self.current += 1;
197            if self.algorithms.has(algorithm) {
198                return Some(algorithm);
199            }
200        }
201
202        None
203    }
204}
205
206impl BinEncodable for SupportedAlgorithms {
207    fn emit(&self, encoder: &mut BinEncoder<'_>) -> ProtoResult<()> {
208        for a in self.iter() {
209            encoder.emit_u8(a.into())?;
210        }
211        Ok(())
212    }
213}
214
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    #[allow(deprecated)]
    fn test_has() {
        let mut supported = SupportedAlgorithms::new();

        supported.set(Algorithm::RSASHA1);

        assert!(supported.has(Algorithm::RSASHA1));
        assert!(!supported.has(Algorithm::RSASHA1NSEC3SHA1));

        let mut supported = SupportedAlgorithms::new();

        supported.set(Algorithm::RSASHA256);
        assert!(!supported.has(Algorithm::RSASHA1));
        assert!(!supported.has(Algorithm::RSASHA1NSEC3SHA1));
        assert!(supported.has(Algorithm::RSASHA256));
    }

    #[allow(deprecated)]
    #[test]
    fn test_iterator() {
        let supported = SupportedAlgorithms::all();
        assert_eq!(supported.iter().count(), 7);
        assert_eq!(supported.len(), 7);

        // it just so happens that the iterator has a fixed order...
        let supported = SupportedAlgorithms::all();
        let mut iter = supported.iter();
        assert_eq!(iter.next(), Some(Algorithm::RSASHA1));
        assert_eq!(iter.next(), Some(Algorithm::RSASHA256));
        assert_eq!(iter.next(), Some(Algorithm::RSASHA1NSEC3SHA1));
        assert_eq!(iter.next(), Some(Algorithm::RSASHA512));
        assert_eq!(iter.next(), Some(Algorithm::ECDSAP256SHA256));
        assert_eq!(iter.next(), Some(Algorithm::ECDSAP384SHA384));
        assert_eq!(iter.next(), Some(Algorithm::ED25519));
        // the iterator must end after the last algorithm, and stay ended
        assert_eq!(iter.next(), None);
        assert_eq!(iter.next(), None);

        let mut supported = SupportedAlgorithms::new();
        supported.set(Algorithm::RSASHA256);
        supported.set(Algorithm::RSASHA512);

        let mut iter = supported.iter();
        assert_eq!(iter.next(), Some(Algorithm::RSASHA256));
        assert_eq!(iter.next(), Some(Algorithm::RSASHA512));
        assert_eq!(iter.next(), None);
    }

    #[test]
    fn test_empty() {
        let supported = SupportedAlgorithms::new();
        assert!(supported.is_empty());
        assert_eq!(supported.len(), 0);
        assert_eq!(supported.iter().next(), None);
        assert_eq!(SupportedAlgorithms::default(), supported);
    }

    #[test]
    #[allow(deprecated)]
    fn test_unrepresentable_algorithms_are_ignored() {
        let supported = SupportedAlgorithms::from_vec(&[
            Algorithm::RSAMD5,
            Algorithm::DSA,
            Algorithm::Unknown(250),
        ]);
        assert!(supported.is_empty());
        assert_eq!(supported.len(), 0);
        assert!(!supported.has(Algorithm::DSA));
        assert!(!supported.has(Algorithm::Unknown(250)));
    }

    #[test]
    fn test_from_single_algorithm() {
        let supported = SupportedAlgorithms::from(Algorithm::ED25519);
        assert_eq!(
            supported,
            SupportedAlgorithms::from_vec(&[Algorithm::ED25519])
        );
        assert!(supported.has(Algorithm::ED25519));
        assert_eq!(supported.len(), 1);
    }

    #[test]
    #[allow(deprecated)]
    fn test_vec() {
        let supported = SupportedAlgorithms::all();
        let array: Vec<u8> = (&supported).into();
        let decoded: SupportedAlgorithms = (&array as &[_]).into();

        assert_eq!(supported, decoded);

        let mut supported = SupportedAlgorithms::new();
        supported.set(Algorithm::RSASHA256);
        supported.set(Algorithm::ECDSAP256SHA256);
        supported.set(Algorithm::ECDSAP384SHA384);
        supported.set(Algorithm::ED25519);
        let array: Vec<u8> = (&supported).into();
        let decoded: SupportedAlgorithms = (&array as &[_]).into();

        assert_eq!(supported, decoded);
        assert!(!supported.has(Algorithm::RSASHA1));
        assert!(!supported.has(Algorithm::RSASHA1NSEC3SHA1));
        assert!(supported.has(Algorithm::RSASHA256));
        assert!(supported.has(Algorithm::ECDSAP256SHA256));
        assert!(supported.has(Algorithm::ECDSAP384SHA384));
        assert!(supported.has(Algorithm::ED25519));
    }
}