hickory-proto 0.26.0

hickory-proto is a safe and secure low-level DNS library. This is the foundational DNS protocol library used by the other higher-level Hickory DNS crates.
Documentation
/*
 * Copyright (C) 2015 Benjamin Fry <benjaminfry@me.com>
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

//! Bitmap for expressing the set of supported algorithms in EDNS.

use alloc::vec::Vec;
use core::fmt::{self, Display, Formatter};

#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};

use tracing::warn;

use super::Algorithm;
use crate::error::ProtoResult;
use crate::serialize::binary::{BinEncodable, BinEncoder};

/// Used to specify the set of SupportedAlgorithms between a client and server
///
/// Each representable algorithm occupies a single bit of a compact bitmap, so the whole set
/// is a small `Copy` value.
#[cfg_attr(feature = "serde", derive(Deserialize, Serialize))]
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct SupportedAlgorithms {
    // One bit per algorithm; the algorithms mapped today fit in the low 7 bits of a u8.
    bit_map: u8,
}

impl SupportedAlgorithms {
    /// Mask of every bit position that `from_pos` maps to an [`Algorithm`] (positions 0..=6).
    const DEFINED_BITS: u8 = 0b0111_1111;

    /// Return a new, empty set of supported algorithms.
    pub fn new() -> Self {
        Self { bit_map: 0 }
    }

    /// Specify the entire set is supported: every algorithm that has a bit position.
    pub fn all() -> Self {
        Self {
            bit_map: Self::DEFINED_BITS,
        }
    }

    /// Based on the set of Algorithms, return the supported set.
    ///
    /// Algorithms that have no bit position (`RSAMD5`, `DSA` and `Unknown` values) are
    /// silently ignored.
    pub fn from_vec(algorithms: &[Algorithm]) -> Self {
        let mut supported = Self::new();

        for a in algorithms {
            supported.set(*a);
        }

        supported
    }

    /// Returns the single-bit mask for `algorithm`, or `None` if it has no slot in the bitmap.
    ///
    /// Despite the name, this returns `1 << position`, not the position itself.
    fn pos(algorithm: Algorithm) -> Option<u8> {
        // not using the values from the RFCs to keep the bit_map space condensed
        #[allow(deprecated)]
        let bit_pos: Option<u8> = match algorithm {
            Algorithm::RSASHA1 => Some(0),
            Algorithm::RSASHA256 => Some(1),
            Algorithm::RSASHA1NSEC3SHA1 => Some(2),
            Algorithm::RSASHA512 => Some(3),
            Algorithm::ECDSAP256SHA256 => Some(4),
            Algorithm::ECDSAP384SHA384 => Some(5),
            Algorithm::ED25519 => Some(6),
            Algorithm::RSAMD5 | Algorithm::DSA | Algorithm::Unknown(_) => None,
        };

        bit_pos.map(|b| 1u8 << b)
    }

    /// Inverse of `pos`: maps a bit position (not a mask) back to its algorithm, or `None`
    /// for positions past the last defined one.
    fn from_pos(pos: u8) -> Option<Algorithm> {
        // TODO: should build a code generator or possibly a macro for deriving these inversions
        #[allow(deprecated)]
        match pos {
            0 => Some(Algorithm::RSASHA1),
            1 => Some(Algorithm::RSASHA256),
            2 => Some(Algorithm::RSASHA1NSEC3SHA1),
            3 => Some(Algorithm::RSASHA512),
            4 => Some(Algorithm::ECDSAP256SHA256),
            5 => Some(Algorithm::ECDSAP384SHA384),
            6 => Some(Algorithm::ED25519),
            _ => None,
        }
    }

    /// Set the specified algorithm as supported.
    ///
    /// This is a no-op for algorithms without a bit position.
    pub fn set(&mut self, algorithm: Algorithm) {
        if let Some(bit_pos) = Self::pos(algorithm) {
            self.bit_map |= bit_pos;
        }
    }

    /// Returns true if the algorithm is supported.
    ///
    /// Algorithms without a bit position are never reported as supported.
    pub fn has(self, algorithm: Algorithm) -> bool {
        if let Some(bit_pos) = Self::pos(algorithm) {
            (bit_pos & self.bit_map) == bit_pos
        } else {
            false
        }
    }

    /// Return an Iterator over the supported set, in ascending bit-position order.
    pub fn iter(&self) -> impl Iterator<Item = Algorithm> + '_ {
        SupportedAlgorithmsIter {
            algorithms: self,
            current: 0,
        }
    }

    /// Return the count of supported algorithms.
    ///
    /// Only bits that map to an algorithm are counted, so this always equals
    /// `self.iter().count()`.
    pub fn len(self) -> u16 {
        // At most 7 bits survive the mask, so the cast cannot truncate.
        (self.bit_map & Self::DEFINED_BITS).count_ones() as u16
    }

    /// Return true if no SupportedAlgorithms are set, this implies the option is not supported.
    ///
    /// Bits that do not map to any algorithm (only settable by constructing the value directly,
    /// e.g. through the optional serde `Deserialize` derive) are ignored, so this always agrees
    /// with `len() == 0` and with `iter()` yielding nothing.
    pub fn is_empty(self) -> bool {
        (self.bit_map & Self::DEFINED_BITS) == 0
    }
}

impl Default for SupportedAlgorithms {
    fn default() -> Self {
        Self::new()
    }
}

impl Display for SupportedAlgorithms {
    /// Formats each supported algorithm (via its own `Display`) in iteration order.
    ///
    /// Items are separated by `", "`, with no leading or trailing separator; an empty set
    /// formats as an empty string.
    fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), fmt::Error> {
        for (index, a) in self.iter().enumerate() {
            // Separator goes *between* items, so the output does not end in a dangling ", ".
            if index > 0 {
                f.write_str(", ")?;
            }
            a.fmt(f)?;
        }

        Ok(())
    }
}

impl<'a> From<&'a [u8]> for SupportedAlgorithms {
    /// Builds a set from raw algorithm bytes, decoded via `Algorithm::from_u8`.
    ///
    /// `Algorithm::Unknown` values are logged with `warn!` and skipped. Every other decoded
    /// algorithm is passed to `set`, which ignores algorithms that have no bit position.
    fn from(values: &'a [u8]) -> Self {
        values
            .iter()
            .map(|&byte| Algorithm::from_u8(byte))
            .fold(Self::new(), |mut acc, algorithm| {
                if let Algorithm::Unknown(v) = algorithm {
                    warn!("unrecognized algorithm: {}", v);
                } else {
                    acc.set(algorithm);
                }
                acc
            })
    }
}

impl<'a> From<&'a SupportedAlgorithms> for Vec<u8> {
    /// Encodes the set as one byte per supported algorithm (each converted with `Into<u8>`).
    ///
    /// Bytes appear in the order produced by `SupportedAlgorithms::iter`. The result can be
    /// turned back into the same set with `SupportedAlgorithms::from(&[u8])`.
    fn from(value: &'a SupportedAlgorithms) -> Self {
        // Size the buffer exactly up front instead of over-allocating and then calling
        // `shrink_to_fit`, which performs a second allocation and copy.
        let mut bytes = Self::with_capacity(usize::from(value.len()));
        bytes.extend(value.iter().map(|algorithm| -> u8 { algorithm.into() }));
        bytes
    }
}

impl From<Algorithm> for SupportedAlgorithms {
    /// Creates a set containing just `algorithm`.
    ///
    /// The set is empty if `algorithm` has no bit position.
    fn from(algorithm: Algorithm) -> Self {
        let mut single = Self::new();
        single.set(algorithm);
        single
    }
}

/// Iterator over the algorithms contained in a [`SupportedAlgorithms`] set; created by
/// `SupportedAlgorithms::iter`.
struct SupportedAlgorithmsIter<'a> {
    /// Next bit position to inspect.
    current: usize,
    /// The set being walked.
    algorithms: &'a SupportedAlgorithms,
}

impl Iterator for SupportedAlgorithmsIter<'_> {
    type Item = Algorithm;

    /// Scans bit positions upward from `current` and yields the next algorithm in the set.
    ///
    /// Returns `None` once `SupportedAlgorithms::from_pos` has no algorithm for the position.
    fn next(&mut self) -> Option<Self::Item> {
        // Positions are handed to `from_pos` as a `u8`; stop if `current` no longer fits.
        if self.current > usize::from(u8::MAX) {
            return None;
        }

        loop {
            // Running past the last defined position ends iteration without advancing.
            let candidate = SupportedAlgorithms::from_pos(self.current as u8)?;
            self.current += 1;
            if self.algorithms.has(candidate) {
                break Some(candidate);
            }
        }
    }
}

impl BinEncodable for SupportedAlgorithms {
    /// Writes one `u8` per supported algorithm, in iteration order.
    ///
    /// Only the algorithm bytes are written; no length is emitted here. The first encoder
    /// error aborts the write and is returned.
    fn emit(&self, encoder: &mut BinEncoder<'_>) -> ProtoResult<()> {
        self.iter()
            .try_for_each(|algorithm| encoder.emit_u8(algorithm.into()))
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    // `#[allow(deprecated)]` is applied wherever a test names an Algorithm variant that may be
    // deprecated (the module's own `pos`/`from_pos` need the same allowance).

    #[test]
    #[allow(deprecated)]
    fn test_has() {
        let mut supported = SupportedAlgorithms::new();

        supported.set(Algorithm::RSASHA1);

        assert!(supported.has(Algorithm::RSASHA1));
        assert!(!supported.has(Algorithm::RSASHA1NSEC3SHA1));

        let mut supported = SupportedAlgorithms::new();

        supported.set(Algorithm::RSASHA256);
        assert!(!supported.has(Algorithm::RSASHA1));
        assert!(!supported.has(Algorithm::RSASHA1NSEC3SHA1));
        assert!(supported.has(Algorithm::RSASHA256));
    }

    #[allow(deprecated)]
    #[test]
    fn test_iterator() {
        let supported = SupportedAlgorithms::all();
        assert_eq!(supported.iter().count(), 7);

        // it just so happens that the iterator has a fixed order...
        let supported = SupportedAlgorithms::all();
        let mut iter = supported.iter();
        assert_eq!(iter.next(), Some(Algorithm::RSASHA1));
        assert_eq!(iter.next(), Some(Algorithm::RSASHA256));
        assert_eq!(iter.next(), Some(Algorithm::RSASHA1NSEC3SHA1));
        assert_eq!(iter.next(), Some(Algorithm::RSASHA512));
        assert_eq!(iter.next(), Some(Algorithm::ECDSAP256SHA256));
        assert_eq!(iter.next(), Some(Algorithm::ECDSAP384SHA384));
        assert_eq!(iter.next(), Some(Algorithm::ED25519));
        // the iterator must be exhausted after the last defined position
        assert_eq!(iter.next(), None);
        assert_eq!(iter.next(), None);

        let mut supported = SupportedAlgorithms::new();
        supported.set(Algorithm::RSASHA256);
        supported.set(Algorithm::RSASHA512);

        let mut iter = supported.iter();
        assert_eq!(iter.next(), Some(Algorithm::RSASHA256));
        assert_eq!(iter.next(), Some(Algorithm::RSASHA512));
        assert_eq!(iter.next(), None);
    }

    #[test]
    #[allow(deprecated)]
    fn test_vec() {
        let supported = SupportedAlgorithms::all();
        let array: Vec<u8> = (&supported).into();
        let decoded: SupportedAlgorithms = (&array as &[_]).into();

        assert_eq!(supported, decoded);

        let mut supported = SupportedAlgorithms::new();
        supported.set(Algorithm::RSASHA256);
        supported.set(Algorithm::ECDSAP256SHA256);
        supported.set(Algorithm::ECDSAP384SHA384);
        supported.set(Algorithm::ED25519);
        let array: Vec<u8> = (&supported).into();
        let decoded: SupportedAlgorithms = (&array as &[_]).into();

        assert_eq!(array.len(), usize::from(supported.len()));
        assert_eq!(supported, decoded);
        assert!(!supported.has(Algorithm::RSASHA1));
        assert!(!supported.has(Algorithm::RSASHA1NSEC3SHA1));
        assert!(supported.has(Algorithm::RSASHA256));
        assert!(supported.has(Algorithm::ECDSAP256SHA256));
        assert!(supported.has(Algorithm::ECDSAP384SHA384));
        assert!(supported.has(Algorithm::ED25519));
    }

    #[test]
    fn test_len_and_is_empty() {
        let empty = SupportedAlgorithms::new();
        assert!(empty.is_empty());
        assert_eq!(empty.len(), 0);
        assert_eq!(empty.iter().next(), None);
        assert_eq!(SupportedAlgorithms::default(), empty);

        let all = SupportedAlgorithms::all();
        assert!(!all.is_empty());
        assert_eq!(all.len(), 7);
        assert_eq!(usize::from(all.len()), all.iter().count());
    }

    #[test]
    #[allow(deprecated)]
    fn test_unmapped_algorithms_are_ignored() {
        let mut supported = SupportedAlgorithms::new();
        supported.set(Algorithm::RSAMD5);
        supported.set(Algorithm::DSA);
        supported.set(Algorithm::Unknown(200));

        assert!(supported.is_empty());
        assert_eq!(supported.len(), 0);
        assert!(!supported.has(Algorithm::RSAMD5));
        assert!(!supported.has(Algorithm::DSA));
        assert!(!supported.has(Algorithm::Unknown(200)));
    }

    #[test]
    fn test_from_algorithm() {
        let supported = SupportedAlgorithms::from(Algorithm::ED25519);
        assert_eq!(supported.len(), 1);
        assert!(supported.has(Algorithm::ED25519));
        assert_eq!(supported.iter().next(), Some(Algorithm::ED25519));
    }
}