use alloc::{borrow::ToOwned, vec::Vec};
use super::rdata::sig::SigInput;
use crate::{
error::{ProtoError, ProtoResult},
rr::{DNSClass, Name, Record},
serialize::binary::{BinEncodable, BinEncoder, NameEncoding},
};
/// Owned byte buffer produced by [`TBS::from_input`] or copied from a byte slice.
///
/// The bytes are exposed read-only through `AsRef<[u8]>`.
// NOTE(review): the name presumably abbreviates "to be signed" — confirm.
pub struct TBS(Vec<u8>);
impl TBS {
    /// Serializes `input` followed by every matching record into a new buffer.
    ///
    /// Only records whose class equals `dns_class`, whose type equals
    /// `input.type_covered`, and whose owner name equals `name` are included;
    /// all other records yielded by `records` are ignored.
    ///
    /// # Errors
    ///
    /// Returns an error if the owner name cannot be derived from `name` and
    /// `input.num_labels`, if any component fails to encode, or if a record's
    /// encoded RDATA is longer than `u16::MAX` bytes.
    pub fn from_input<'a>(
        name: &Name,
        dns_class: DNSClass,
        input: &SigInput,
        records: impl Iterator<Item = &'a Record>,
    ) -> ProtoResult<Self> {
        Self::new(name, dns_class, input, records)
    }

    /// Implementation behind [`TBS::from_input`]; see that method for the contract.
    fn new<'a>(
        name: &Name,
        dns_class: DNSClass,
        input: &SigInput,
        records: impl Iterator<Item = &'a Record>,
    ) -> ProtoResult<Self> {
        // Keep only the records that belong to the RRset described by the arguments.
        let mut rrset: Vec<&Record> = records
            .filter(|record| {
                dns_class == record.dns_class
                    && input.type_covered == record.record_type()
                    && name == &record.name
            })
            .collect();

        // Order the records by `Record`'s `Ord` implementation before encoding.
        rrset.sort();

        // The emitted owner name depends on the label count carried by `input`.
        let name = determine_name(name, input.num_labels)?;

        let mut buf = Vec::new();
        let mut encoder = BinEncoder::new(&mut buf);
        encoder.set_canonical_form(true);
        encoder.set_name_encoding(NameEncoding::Uncompressed);

        // The signature input comes first, then one entry per record.
        input.emit(&mut encoder)?;

        for record in rrset {
            {
                // Scoped so the temporary lowercase name encoding ends before
                // `encoder` is used directly again.
                let mut encoder_name =
                    encoder.with_name_encoding(NameEncoding::UncompressedLowercase);
                name.emit(&mut encoder_name)?;
            }

            input.type_covered.emit(&mut encoder)?;
            dns_class.emit(&mut encoder)?;
            // The TTL written is the one from `input`, not the record's own.
            encoder.emit_u32(input.original_ttl)?;

            // Reserve the 16-bit RDATA length, write the RDATA, then patch the
            // length in once the encoded size is known.
            let rdata_length_place = encoder.place::<u16>()?;
            record.data.emit(&mut encoder)?;
            let length = u16::try_from(encoder.len_since_place(&rdata_length_place))
                .map_err(|_| ProtoError::from("RDATA length exceeds u16::MAX"))?;
            rdata_length_place.replace(&mut encoder, length)?;
        }

        Ok(Self(buf))
    }
}
/// Wraps a copy of an already-serialized byte slice.
impl From<&[u8]> for TBS {
    fn from(bytes: &[u8]) -> Self {
        let owned: Vec<u8> = bytes.to_owned();
        Self(owned)
    }
}
impl AsRef<[u8]> for TBS {
fn as_ref(&self) -> &[u8] {
self.0.as_ref()
}
}
/// Derives the owner name to encode from `name` and a label count.
///
/// * If `num_labels` equals the label count of `name`, a clone of `name` is
///   returned.
/// * If `num_labels` is smaller, the result is `*` followed by `name` trimmed
///   to `num_labels` labels (just `*` when the trimmed name is the root).
///
/// # Errors
///
/// Returns an error if `num_labels` is greater than the label count of `name`,
/// or if building the `*`-prefixed name fails.
fn determine_name(name: &Name, num_labels: u8) -> Result<Name, ProtoError> {
    let fqdn_labels = name.num_labels();

    if fqdn_labels == num_labels {
        return Ok(name.clone());
    }

    // A label count larger than the name itself cannot be satisfied.
    if num_labels > fqdn_labels {
        return Err(format!("could not determine name from {name}").into());
    }

    // Propagate a construction failure instead of panicking in library code.
    let star_name = Name::from_labels(vec![b"*" as &[u8]])?;
    let rightmost = name.trim_to(usize::from(num_labels));
    if rightmost.is_root() {
        return Ok(star_name);
    }

    star_name.append_name(&rightmost)
}