use alloc::{borrow::ToOwned, string::String, sync::Arc, vec::Vec};
use core::{fmt, str::FromStr};
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};
use crate::{
dnssec::{
Algorithm, DigestType, PublicKey, PublicKeyBuf, Verifier,
crypto::{Digest, decode_public_key},
},
error::ProtoResult,
rr::{Name, RecordData, RecordDataDecodable, RecordType, record_data::RData},
serialize::{
binary::{
BinDecodable, BinDecoder, BinEncodable, BinEncoder, DecodeError, NameEncoding,
Restrict, RestrictedMath,
},
txt::ParseError,
},
};
use super::DNSSECRData;
/// RDATA of a DNSKEY resource record: a 16-bit flags word plus a public key.
///
/// The protocol field is not stored. It is always written as `3` when encoding, and any
/// other value is rejected when decoding from the wire or parsing from text.
#[cfg_attr(feature = "serde", derive(Deserialize, Serialize))]
#[derive(Debug, PartialEq, Eq, Hash, Clone)]
pub struct DNSKEY {
    /// Raw flags word; every bit (including reserved ones) is kept exactly as given.
    flags: u16,
    /// Public key bytes together with the DNSSEC algorithm they belong to.
    public_key: PublicKeyBuf,
}
impl DNSKEY {
    /// Zone Key flag (RFC 4034 §2.1.1, bit 7 counting from the most significant bit).
    const ZONE_KEY_FLAG: u16 = 0b0000_0001_0000_0000;
    /// Secure Entry Point flag (RFC 4034 §2.1.1, bit 15).
    const SECURE_ENTRY_POINT_FLAG: u16 = 0b0000_0000_0000_0001;
    /// REVOKE flag (RFC 5011 §7, bit 8).
    const REVOKE_FLAG: u16 = 0b0000_0000_1000_0000;

    /// Builds a DNSKEY for `public_key` with the Zone Key and Secure Entry Point flags set and
    /// the REVOKE flag clear.
    pub fn from_key(public_key: &dyn PublicKey) -> Self {
        Self::new(
            true,
            true,
            false,
            PublicKeyBuf::new(public_key.public_bytes().to_owned(), public_key.algorithm()),
        )
    }

    /// Builds a DNSKEY whose flags word contains exactly the requested flag bits.
    ///
    /// All other (reserved) bits are left clear; use [`DNSKEY::with_flags`] to control the
    /// full flags word.
    pub fn new(
        zone_key: bool,
        secure_entry_point: bool,
        revoke: bool,
        public_key: PublicKeyBuf,
    ) -> Self {
        let mut flags: u16 = 0;
        if zone_key {
            flags |= Self::ZONE_KEY_FLAG;
        }
        if secure_entry_point {
            flags |= Self::SECURE_ENTRY_POINT_FLAG;
        }
        if revoke {
            flags |= Self::REVOKE_FLAG;
        }
        Self::with_flags(flags, public_key)
    }

    /// Parses presentation-format RDATA tokens: `<flags> <protocol> <algorithm> <key...>`.
    ///
    /// All tokens after the algorithm are concatenated before base64 decoding, so a key split
    /// across whitespace is accepted.
    ///
    /// # Errors
    ///
    /// Returns a [`ParseError`] in these cases:
    /// - the flags, protocol, algorithm or key is missing;
    /// - a numeric field fails to parse;
    /// - the protocol is not `3`;
    /// - the key is not valid base64.
    pub(crate) fn from_tokens<'i>(
        mut tokens: impl Iterator<Item = &'i str>,
    ) -> Result<Self, ParseError> {
        let flags_str = tokens
            .next()
            .ok_or(ParseError::Message("flags not present"))?;
        let protocol_str = tokens
            .next()
            .ok_or(ParseError::Message("protocol not present"))?;
        let algorithm_str = tokens
            .next()
            .ok_or(ParseError::Message("algorithm not present"))?;
        let flags = u16::from_str(flags_str)?;
        let protocol = u8::from_str(protocol_str)?;
        if protocol != 3 {
            return Err(ParseError::Message("protocol field must be 3"));
        }
        let algorithm = Algorithm::from_u8(algorithm_str.parse()?);
        // The key may be split across several whitespace-separated tokens; rejoin them.
        let public_key_str = tokens.collect::<String>();
        if public_key_str.is_empty() {
            return Err(ParseError::Message("public key not present"));
        }
        let public_key = data_encoding::BASE64.decode(public_key_str.as_bytes())?;
        Ok(Self::with_flags(
            flags,
            PublicKeyBuf::new(public_key, algorithm),
        ))
    }

    /// Builds a DNSKEY from a raw flags word, preserving every bit including reserved ones.
    pub fn with_flags(flags: u16, public_key: PublicKeyBuf) -> Self {
        Self { flags, public_key }
    }

    /// Returns `true` if the Zone Key flag (`0x0100`) is set.
    pub fn zone_key(&self) -> bool {
        self.flags & Self::ZONE_KEY_FLAG != 0
    }

    /// Returns `true` if the Secure Entry Point flag (`0x0001`) is set.
    pub fn secure_entry_point(&self) -> bool {
        self.flags & Self::SECURE_ENTRY_POINT_FLAG != 0
    }

    /// Returns `true` when both the Secure Entry Point and Zone Key flags are set and the key
    /// is not revoked.
    pub fn is_key_signing_key(&self) -> bool {
        self.secure_entry_point() && self.zone_key() && !self.revoke()
    }

    /// Returns `true` if the REVOKE flag (`0x0080`) is set.
    pub fn revoke(&self) -> bool {
        self.flags & Self::REVOKE_FLAG != 0
    }

    /// The public key and its algorithm.
    pub fn public_key(&self) -> &PublicKeyBuf {
        &self.public_key
    }

    /// The raw flags word, including any reserved bits.
    pub fn flags(&self) -> u16 {
        self.flags
    }

    /// Digests the lowercased, uncompressed owner `name` followed by this record's RDATA wire
    /// format with `digest_type`.
    ///
    /// This is the digest input that RFC 4034 §5.1.4 defines for DS records.
    ///
    /// # Errors
    ///
    /// Returns an error if the name or RDATA cannot be serialized, or if [`Digest::new`] fails.
    pub fn to_digest(&self, name: &Name, digest_type: DigestType) -> ProtoResult<Digest> {
        let mut buf: Vec<u8> = Vec::new();
        {
            let mut encoder: BinEncoder<'_> = BinEncoder::new(&mut buf);
            encoder.set_name_encoding(NameEncoding::UncompressedLowercase);
            if let Err(e) = name
                .to_lowercase()
                .emit(&mut encoder)
                .and_then(|_| self.emit(&mut encoder))
            {
                tracing::warn!("error serializing dnskey: {e}");
                return Err(format!("error serializing dnskey: {e}").into());
            }
        }
        Ok(Digest::new(&buf, digest_type)?)
    }

    /// Computes the key tag over this record's RDATA wire format (RFC 4034 Appendix B).
    ///
    /// NOTE: the RSA/MD5 (algorithm 1) special case of RFC 4034 Appendix B.1 is not applied;
    /// every algorithm uses the checksum in [`DNSKEY::calculate_key_tag_internal`].
    ///
    /// # Errors
    ///
    /// Returns an error if the RDATA cannot be encoded.
    pub fn calculate_key_tag(&self) -> ProtoResult<u16> {
        let mut bytes: Vec<u8> = Vec::with_capacity(512);
        {
            let mut e = BinEncoder::new(&mut bytes);
            self.emit(&mut e)?;
        }
        Ok(Self::calculate_key_tag_internal(&bytes))
    }

    /// The RFC 4034 Appendix B checksum over `bytes`.
    ///
    /// Even-indexed bytes are treated as the high octet of a 16-bit word and odd-indexed bytes
    /// as the low octet. The words are summed, and the carry above bit 16 is folded back in once.
    pub fn calculate_key_tag_internal(bytes: &[u8]) -> u16 {
        // A u64 accumulator keeps the sum from overflowing (which would panic in debug builds)
        // even for inputs far larger than any DNSKEY RDATA.
        let mut ac: u64 = bytes
            .iter()
            .enumerate()
            .map(|(i, &byte)| {
                let byte = u64::from(byte);
                if i & 1 == 0 { byte << 8 } else { byte }
            })
            .sum();
        ac += (ac >> 16) & 0xFFFF;
        // Masked to 16 bits first, so the cast cannot discard meaningful data.
        (ac & 0xFFFF) as u16
    }
}
impl From<DNSKEY> for RData {
fn from(key: DNSKEY) -> Self {
Self::DNSSEC(DNSSECRData::DNSKEY(key))
}
}
impl BinEncodable for DNSKEY {
    /// Writes the RDATA in wire order: the 16-bit flags, the protocol octet (always `3`), the
    /// algorithm, then the raw public key bytes.
    fn emit(&self, encoder: &mut BinEncoder<'_>) -> ProtoResult<()> {
        let flags = self.flags;
        let algorithm = self.public_key.algorithm();
        let key_bytes = self.public_key.public_bytes();

        encoder.emit_u16(flags)?;
        encoder.emit(3)?;
        algorithm.emit(encoder)?;
        encoder.emit_vec(key_bytes)?;
        Ok(())
    }
}
impl<'r> RecordDataDecodable<'r> for DNSKEY {
    /// Decodes `length` bytes of DNSKEY RDATA: flags, protocol, algorithm, then the key.
    ///
    /// # Errors
    ///
    /// Fails in these cases:
    /// - the protocol octet is not `3`;
    /// - `length` is shorter than the 4-byte fixed header;
    /// - any read from `decoder` fails.
    fn read_data(decoder: &mut BinDecoder<'r>, length: Restrict<u16>) -> Result<Self, DecodeError> {
        // Any flags value is accepted, including reserved bits.
        let flags = decoder.read_u16()?.unverified();

        // The protocol octet carries no information but must equal 3.
        let _protocol: u8 = decoder
            .read_u8()?
            .verify_unwrap(|&value| value == 3)
            .map_err(DecodeError::DnsKeyProtocolNot3)?;

        let algorithm = Algorithm::read(decoder)?;

        // Everything after the 4-byte fixed header (2 flags + 1 protocol + 1 algorithm) is the
        // public key.
        let key_len = length
            .map(usize::from)
            .checked_sub(4)
            .map_err(|len| DecodeError::IncorrectRDataLengthRead { read: 4, len })?
            .unverified();
        let key_bytes = decoder.read_vec(key_len)?.unverified();

        Ok(Self::with_flags(flags, PublicKeyBuf::new(key_bytes, algorithm)))
    }
}
impl RecordData for DNSKEY {
    /// Borrows the inner key when `data` is `RData::DNSSEC(DNSSECRData::DNSKEY(..))`.
    fn try_borrow(data: &RData) -> Option<&Self> {
        if let RData::DNSSEC(DNSSECRData::DNSKEY(dnskey)) = data {
            Some(dnskey)
        } else {
            None
        }
    }

    /// Always [`RecordType::DNSKEY`].
    fn record_type(&self) -> RecordType {
        RecordType::DNSKEY
    }

    /// Wraps `self` as `RData::DNSSEC(DNSSECRData::DNSKEY(..))`.
    fn into_rdata(self) -> RData {
        let inner = DNSSECRData::DNSKEY(self);
        RData::DNSSEC(inner)
    }
}
impl Verifier for DNSKEY {
    /// The DNSSEC algorithm stored alongside the public key.
    fn algorithm(&self) -> Algorithm {
        let key_buf = &self.public_key;
        key_buf.algorithm()
    }

    /// Decodes the stored key bytes into a usable public key for its algorithm.
    fn key(&self) -> ProtoResult<Arc<dyn PublicKey + '_>> {
        let key_buf = &self.public_key;
        let bytes = key_buf.public_bytes();
        decode_public_key(bytes, key_buf.algorithm())
    }
}
impl fmt::Display for DNSKEY {
    /// Renders presentation format: `<flags> 3 <algorithm number> <base64 key>`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> {
        let flags = self.flags;
        let alg = u8::from(self.public_key.algorithm());
        let key = data_encoding::BASE64.encode(self.public_key.public_bytes());
        write!(f, "{flags} 3 {alg} {key}")
    }
}
#[cfg(test)]
mod tests {
    #![allow(clippy::dbg_macro, clippy::print_stdout)]

    use alloc::string::ToString;
    // `println!` is only available when `std` is enabled, so every use below is guarded by
    // the same cfg.
    #[cfg(feature = "std")]
    use std::println;

    use rustls_pki_types::PrivateKeyDer;

    use super::*;
    use crate::dnssec::{SigningKey, crypto::EcdsaSigningKey};

    /// Round-trips a freshly generated ECDSA key through the wire format and checks that a
    /// digest can be computed over it.
    #[test]
    fn test() {
        let algorithm = Algorithm::ECDSAP256SHA256;
        let pkcs8 = EcdsaSigningKey::generate_pkcs8(algorithm).unwrap();
        let signing_key =
            EcdsaSigningKey::from_key_der(&PrivateKeyDer::from(pkcs8), algorithm).unwrap();
        let rdata = DNSKEY::new(
            true,
            true,
            false,
            PublicKeyBuf::new(
                signing_key
                    .to_public_key()
                    .unwrap()
                    .public_bytes()
                    .to_owned(),
                algorithm,
            ),
        );
        let mut bytes = Vec::new();
        let mut encoder: BinEncoder<'_> = BinEncoder::new(&mut bytes);
        assert!(rdata.emit(&mut encoder).is_ok());
        let bytes = encoder.into_bytes();
        #[cfg(feature = "std")]
        println!("bytes: {bytes:?}");
        let mut decoder: BinDecoder<'_> = BinDecoder::new(bytes);
        let read_rdata = DNSKEY::read_data(&mut decoder, Restrict::new(bytes.len() as u16));
        let read_rdata = read_rdata.expect("error decoding");
        assert_eq!(rdata, read_rdata);
        assert!(
            rdata
                .to_digest(
                    &Name::parse("www.example.com.", None).unwrap(),
                    DigestType::SHA256
                )
                .is_ok()
        );
    }

    /// Every flag bit, reserved ones included, must survive an encode/decode round trip.
    #[test]
    fn test_reserved_flags() {
        let rdata =
            DNSKEY::with_flags(u16::MAX, PublicKeyBuf::new(vec![0u8], Algorithm::RSASHA256));
        let mut bytes = Vec::new();
        let mut encoder = BinEncoder::new(&mut bytes);
        rdata.emit(&mut encoder).expect("error encoding");
        let bytes = encoder.into_bytes();
        #[cfg(feature = "std")]
        println!("bytes: {bytes:?}");
        let mut decoder = BinDecoder::new(bytes);
        let read_rdata = DNSKEY::read_data(&mut decoder, Restrict::new(bytes.len() as u16))
            .expect("error decoding");
        assert_eq!(rdata, read_rdata);
    }

    #[test]
    fn test_calculate_key_tag_checksum() {
        let test_text = "The quick brown fox jumps over the lazy dog";
        let test_vectors = vec![
            (vec![], 0),
            (vec![0, 0, 0, 0], 0),
            (vec![0xff, 0xff, 0xff, 0xff], 0xffff),
            (vec![1, 0, 0, 0], 0x0100),
            (vec![0, 1, 0, 0], 0x0001),
            (vec![0, 0, 1, 0], 0x0100),
            (test_text.as_bytes().to_vec(), 0x8d5b),
        ];
        for (input_data, exp_result) in test_vectors {
            let result = DNSKEY::calculate_key_tag_internal(&input_data);
            assert_eq!(result, exp_result);
        }
    }

    const ENCODED: &str = "aGVsbG8=";

    /// Every DNSKEY line of the bundled root trust anchor must parse.
    #[test]
    fn accepts_real_world_data() {
        let trust_anchor = include_str!("../../../tests/test-data/root.key");
        let mut did_parse = false;
        for line in trust_anchor.lines() {
            let trimmed = line.trim_start();
            // Skip comments and blank lines; neither carries a record.
            if trimmed.is_empty() || trimmed.starts_with(';') {
                continue;
            }
            let parts = line.split_whitespace().skip(4);
            DNSKEY::from_tokens(parts).expect("could not parse");
            did_parse = true;
        }
        assert!(did_parse);
    }

    #[cfg(feature = "__dnssec")]
    #[test]
    fn it_works() {
        let algorithm = Algorithm::ECDSAP256SHA256;
        let pkcs8 = EcdsaSigningKey::generate_pkcs8(algorithm).unwrap();
        let signing_key = EcdsaSigningKey::from_pkcs8(&pkcs8, algorithm).unwrap();
        let public_key = signing_key.to_public_key().unwrap();
        let encoded = data_encoding::BASE64.encode(public_key.public_bytes());
        let input = format!("256 3 13 {encoded}");
        let expected = DNSKEY::new(
            true,
            false,
            false,
            PublicKeyBuf::new(public_key.public_bytes().to_vec(), algorithm),
        );
        assert_eq!(expected, parse_ok(&input),);
    }

    #[cfg(feature = "__dnssec")]
    #[test]
    fn secure_entry_point() {
        let algorithm = Algorithm::ECDSAP256SHA256;
        let pkcs8 = EcdsaSigningKey::generate_pkcs8(algorithm).unwrap();
        let signing_key = EcdsaSigningKey::from_pkcs8(&pkcs8, algorithm).unwrap();
        let public_key = signing_key.to_public_key().unwrap();
        let encoded = data_encoding::BASE64.encode(public_key.public_bytes());
        let input = format!("257 3 13 {encoded}");
        let expected = DNSKEY::new(
            true,
            true,
            false,
            PublicKeyBuf::new(public_key.public_bytes().to_vec(), algorithm),
        );
        assert_eq!(expected, parse_ok(&input),);
    }

    #[test]
    fn incomplete() {
        let cases = ["", "256", "256 3", "256 3 8"];
        for case in cases {
            let err = parse_err(case);
            assert!(err.to_string().contains("not present"))
        }
    }

    #[test]
    fn reserved_flags() {
        let public_key = PublicKeyBuf::new(b"hello".to_vec(), Algorithm::RSASHA256);
        let expected = DNSKEY::with_flags(2, public_key);
        assert_eq!(expected, parse_ok(&format!("2 3 8 {ENCODED}")));
    }

    #[test]
    fn bad_protocol() {
        let err = parse_err(&format!("256 0 8 {ENCODED}"));
        assert!(err.to_string().contains("protocol field"))
    }

    #[test]
    fn bad_public_key() {
        let mut input = format!("256 3 8 {ENCODED}");
        // Drop the trailing '=' padding so the base64 payload becomes invalid.
        input.pop().unwrap();
        let err = parse_err(&input);
        assert!(err.to_string().contains("data encoding error"))
    }

    /// Parses `input` as DNSKEY presentation RDATA, panicking on failure.
    fn parse_ok(input: &str) -> DNSKEY {
        DNSKEY::from_tokens(input.split_whitespace()).expect("parsing failed")
    }

    /// Parses `input` as DNSKEY presentation RDATA, panicking if it unexpectedly succeeds.
    fn parse_err(input: &str) -> ParseError {
        DNSKEY::from_tokens(input.split_whitespace()).expect_err("parsing did not fail")
    }
}