use crate::{Error, Keypair, PublicKey, Result};
use bytes::{Bytes, BytesMut};
use ed25519_dalek::Signature;
#[cfg(feature = "dht")]
use mainline::common::MutableItem;
use self_cell::self_cell;
use simple_dns::{
rdata::{RData, A, AAAA},
Name, Packet, ResourceRecord,
};
use std::{
char,
fmt::{self, Display, Formatter},
net::{Ipv4Addr, Ipv6Addr},
time::{Duration, SystemTime},
};
/// Label separator; `normalize_name` uses it to detect and strip a single trailing dot.
const DOT: char = '.';
// Self-referential storage: `owner` holds the raw signed-packet bytes and `dependent`
// (`InnerParsed`) holds values parsed from them, including a `Packet` that borrows
// from `owner`. `self_cell!` ties the dependent's borrow to the owner it lives beside.
// `#[covariant]` marks the dependent as covariant in its lifetime, and `impl{Debug}`
// asks the macro to generate a `Debug` impl for the cell.
self_cell!(
struct Inner {
// `Inner::try_from_parts` lays this buffer out as:
// public_key (32) || signature (64) || timestamp (8, big-endian) || encoded DNS packet.
owner: Bytes,
#[covariant]
dependent: InnerParsed,
}
impl{Debug}
);
/// Values parsed out of the owner bytes of `Inner`.
///
/// `packet` borrows from those bytes, which is why the struct carries the `'a` lifetime.
#[derive(Debug)]
struct InnerParsed<'a> {
/// Key the signature was verified against (relay path) or that produced it (`from_packet`).
public_key: PublicKey,
/// Signed timestamp. It is also the `seq` value inside the signed bytes (see `signable`).
/// `SignedPacket::from_packet` sets it to microseconds since the Unix epoch.
timestamp: u64,
/// Ed25519 signature over `signable(timestamp, encoded packet)`.
signature: Signature,
/// DNS packet parsed from the encoded packet bytes; borrows from the owner buffer.
packet: Packet<'a>,
}
impl Inner {
    /// Builds an `Inner` from a relay response payload.
    ///
    /// The payload layout is
    /// `signature (64) || timestamp (8, big-endian) || encoded DNS packet`.
    /// The signature is verified against `public_key` over `signable(timestamp, packet)`
    /// before the DNS packet is parsed.
    ///
    /// The stored owner buffer is built by [`Inner::try_from_parts`]. It therefore always
    /// has the layout `public_key || signature || timestamp || packet`, whichever
    /// constructor was used. Fixed-offset accessors on the owner (slices at 32 and 104)
    /// rely on that.
    ///
    /// # Errors
    ///
    /// - `Error::InvalidSingedPacketBytes` if the payload is shorter than 72 bytes.
    /// - `Error::PacketTooLarge` if the payload is longer than 1072 bytes.
    /// - The converted error from `PublicKey::verify` if the signature does not match.
    /// - A converted DNS parse error if the packet bytes do not parse.
    fn try_from_response(public_key: PublicKey, response: Bytes) -> Result<Self> {
        // 64-byte signature + 8-byte timestamp is the smallest possible payload.
        if response.len() < 72 {
            return Err(Error::InvalidSingedPacketBytes(response.len()));
        }
        // 72-byte header + at most 1000 bytes of encoded packet.
        if response.len() > 1072 {
            return Err(Error::PacketTooLarge(response.len()));
        }

        let signature =
            Signature::from_bytes(response[..64].try_into().expect("signature is 64 bytes"));
        let timestamp =
            u64::from_be_bytes(response[64..72].try_into().expect("seq is 8 bytes"));
        let encoded_packet = response.slice(72..);

        public_key.verify(&signable(timestamp, &encoded_packet), &signature)?;

        // Delegate so the owner gets the public key prefix. Storing `response` directly
        // would shift every fixed offset used by the accessors by 32 bytes.
        Self::try_from_parts(public_key, encoded_packet, timestamp, signature)
    }

    /// Assembles an `Inner` from already-known parts.
    ///
    /// The owner buffer is laid out as
    /// `public_key (32) || signature (64) || timestamp (8, big-endian) || encoded_packet`.
    /// The DNS packet is parsed from offset 104.
    /// No signature verification is performed here.
    ///
    /// # Errors
    ///
    /// Returns a converted DNS parse error if `encoded_packet` does not parse.
    fn try_from_parts(
        public_key: PublicKey,
        encoded_packet: Bytes,
        timestamp: u64,
        signature: Signature,
    ) -> Result<Self> {
        let mut bytes = BytesMut::with_capacity(encoded_packet.len() + 104);
        bytes.extend_from_slice(&public_key.to_bytes());
        bytes.extend_from_slice(&signature.to_bytes());
        bytes.extend_from_slice(&timestamp.to_be_bytes());
        bytes.extend_from_slice(&encoded_packet);

        Self::try_new(bytes.freeze(), |bytes| {
            Packet::parse(&bytes[104..]).map(|packet| InnerParsed {
                public_key,
                timestamp,
                signature,
                packet,
            })
        })
        .map_err(|e| e.into())
    }
}
/// A DNS packet together with the public key, timestamp and Ed25519 signature that
/// authenticate it. All of it is backed by a single byte buffer.
#[derive(Debug)]
pub struct SignedPacket {
// Raw bytes plus the values parsed from them (see `Inner`).
inner: Inner,
}
impl SignedPacket {
pub fn from_relay_response(public_key: PublicKey, response: Bytes) -> Result<SignedPacket> {
let inner = Inner::try_from_response(public_key, response)?;
Ok(SignedPacket { inner })
}
pub fn as_relay_request(&self) -> Bytes {
self.inner.borrow_owner().slice(32..)
}
pub fn from_packet(keypair: &Keypair, packet: &Packet) -> Result<SignedPacket> {
let mut inner = Packet::new_reply(0);
let origin = keypair.public_key().to_z32();
let normalized_names: Vec<String> = packet
.answers
.iter()
.map(|answer| normalize_name(&origin, answer.name.to_string()))
.collect();
packet
.answers
.iter()
.enumerate()
.for_each(|(index, answer)| {
let new_new_name = Name::new_unchecked(&normalized_names[index]);
inner.answers.push(ResourceRecord::new(
new_new_name.clone(),
answer.class,
answer.ttl,
answer.rdata.clone(),
))
});
let encoded_packet: Bytes = inner.build_bytes_vec_compressed()?.into();
if encoded_packet.len() > 1000 {
return Err(Error::PacketTooLarge(encoded_packet.len()));
}
let timestamp = system_time().as_micros() as u64;
let signature = keypair.sign(&signable(timestamp, &encoded_packet));
Ok(SignedPacket {
inner: Inner::try_from_parts(
keypair.public_key(),
encoded_packet,
timestamp,
signature,
)?,
})
}
pub fn public_key(&self) -> &PublicKey {
&self.inner.borrow_dependent().public_key
}
pub fn timestamp(&self) -> &u64 {
&self.inner.borrow_dependent().timestamp
}
pub fn packet(&self) -> &Packet {
&self.inner.borrow_dependent().packet
}
pub fn signature(&self) -> &Signature {
&self.inner.borrow_dependent().signature
}
pub fn encoded_packet(&self) -> Bytes {
self.inner.borrow_owner().slice(104..)
}
pub fn more_recent_than(&self, other: &SignedPacket) -> bool {
if self.timestamp() < other.timestamp() {
return false;
}
if self.timestamp() == other.timestamp() && self.encoded_packet() < other.encoded_packet() {
return false;
}
true
}
pub fn resource_records(&self, name: &str) -> impl Iterator<Item = &ResourceRecord> {
let origin = self.public_key().to_z32();
let normalized_name = normalize_name(&origin, name.to_string());
self.packet()
.answers
.iter()
.filter(move |rr| rr.name == Name::new(&normalized_name).unwrap())
}
pub fn elapsed(&self) -> Duration {
system_time() - Duration::from_micros(*self.timestamp())
}
pub fn fresh_resource_records(&self, name: &str) -> impl Iterator<Item = &ResourceRecord> {
let origin = self.public_key().to_z32();
let normalized_name = normalize_name(&origin, name.to_string());
let elapsed = self.elapsed().as_secs() as u32;
self.packet()
.answers
.iter()
.filter(move |rr| rr.name == Name::new(&normalized_name).unwrap() && rr.ttl > elapsed)
}
}
fn signable(timestamp: u64, v: &Bytes) -> Bytes {
let mut signable = format!("3:seqi{}e1:v{}:", timestamp, v.len()).into_bytes();
signable.extend(v);
signable.into()
}
/// Current wall-clock time, expressed as a `Duration` since the Unix epoch.
///
/// # Panics
///
/// Panics with `"time drift"` if the system clock reads earlier than the epoch.
fn system_time() -> Duration {
    SystemTime::UNIX_EPOCH.elapsed().expect("time drift")
}
#[cfg(feature = "dht")]
impl From<&SignedPacket> for MutableItem {
    /// Hands the packet's existing key, signature and encoded packet to
    /// `MutableItem::new_signed_unchecked`. The timestamp is cast to `i64` and used as
    /// the item's `seq`.
    fn from(s: &SignedPacket) -> Self {
        let public_key = s.public_key().to_bytes();
        let signature = s.signature().to_bytes();
        let value = s.encoded_packet();
        let seq = *s.timestamp() as i64;
        Self::new_signed_unchecked(public_key, signature, value, seq, None)
    }
}
#[cfg(feature = "dht")]
impl TryFrom<MutableItem> for SignedPacket {
    type Error = Error;

    /// Rebuilds a `SignedPacket` from a DHT mutable item. The item's `seq` becomes the
    /// timestamp and its value the encoded DNS packet. The signature is not re-verified
    /// here.
    ///
    /// # Errors
    ///
    /// Returns an error if the item's value does not parse as a DNS packet.
    fn try_from(item: MutableItem) -> Result<Self> {
        // NOTE(review): panics if the key does not convert into a `PublicKey`; the
        // conversion error is not propagated.
        let public_key: PublicKey = item.key().to_owned().try_into().unwrap();
        let signature: Signature = item.signature().into();
        let seq = item.seq().to_owned() as u64;
        let encoded_packet = Bytes::from(item.value().to_vec());

        let inner = Inner::try_from_parts(public_key, encoded_packet, seq, signature)?;
        Ok(Self { inner })
    }
}
impl AsRef<[u8]> for SignedPacket {
fn as_ref(&self) -> &[u8] {
self.inner.borrow_owner()
}
}
impl Display for SignedPacket {
    /// Human-readable dump of the public key, timestamp and signature, followed by one
    /// entry per answer (`name IN ttl rdata`).
    ///
    /// A, AAAA, CNAME and TXT data get dedicated renderings; any other record type falls
    /// back to its `Debug` form.
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        write!(
            f,
            "SignedPacket ({}):\n timestamp: {},\n signature: {}\n records:\n",
            self.public_key(),
            self.timestamp(),
            self.signature(),
        )?;

        for answer in self.packet().answers.iter() {
            let rendered = match &answer.rdata {
                RData::A(A { address }) => format!("A {}", Ipv4Addr::from(*address)),
                RData::AAAA(AAAA { address }) => format!("AAAA {}", Ipv6Addr::from(*address)),
                RData::CNAME(name) => "CNAME ".to_string() + &name.to_string(),
                RData::TXT(txt) => {
                    let text: String = txt
                        .clone()
                        .try_into()
                        .unwrap_or_else(|_| "__INVALID_TXT_VALUE_".to_string());
                    format!("TXT \"{}\"", text)
                }
                _ => format!("{:?}", answer.rdata),
            };
            writeln!(f, " {} IN {} {}\n", answer.name, answer.ttl, rendered)?;
        }

        writeln!(f)
    }
}
/// Qualifies `name` relative to `origin`.
///
/// A single trailing `.` is stripped first. Then, based on the last `.`-separated label:
/// - equal to `origin`: the stripped name is returned as-is;
/// - `"@"` or empty: `origin` itself is returned;
/// - anything else: `"{name}.{origin}"` is returned.
fn normalize_name(origin: &str, name: String) -> String {
    let trimmed = name.strip_suffix(DOT).unwrap_or(name.as_str());
    let last_label = trimmed.rsplit('.').next().unwrap_or("");

    match last_label {
        label if label == origin => trimmed.to_string(),
        "@" | "" => origin.to_string(),
        _ => format!("{}.{}", trimmed, origin),
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dns;

    #[test]
    fn normalize_names() {
        let origin = "ed4mn3aoazuf1ahpy9rz1nyswhukbj5483ryefwkue7fbp3egkzo";
        assert_eq!(normalize_name(origin, ".".to_string()), origin);
        assert_eq!(normalize_name(origin, "@".to_string()), origin);
        assert_eq!(normalize_name(origin, "@.".to_string()), origin);
        assert_eq!(normalize_name(origin, origin.to_string()), origin);
        assert_eq!(
            normalize_name(origin, "_derp_region.irorh".to_string()),
            format!("_derp_region.irorh.{}", origin)
        );
        assert_eq!(
            normalize_name(origin, format!("_derp_region.irorh.{}", origin)),
            format!("_derp_region.irorh.{}", origin)
        );
        assert_eq!(
            normalize_name(origin, format!("_derp_region.irorh.{}.", origin)),
            format!("_derp_region.irorh.{}", origin)
        );
    }

    #[test]
    fn sign_verify() {
        let keypair = Keypair::random();
        let mut packet = Packet::new_reply(0);
        packet.answers.push(ResourceRecord::new(
            Name::new("_derp_region.iroh.").unwrap(),
            simple_dns::CLASS::IN,
            30,
            RData::A(A {
                address: Ipv4Addr::new(1, 1, 1, 1).into(),
            }),
        ));
        let signed_packet = SignedPacket::from_packet(&keypair, &packet).unwrap();
        assert!(SignedPacket::from_relay_response(
            signed_packet.public_key().clone(),
            signed_packet.as_relay_request()
        )
        .is_ok());
    }

    #[test]
    fn from_too_large_bytes() {
        let keypair = Keypair::random();
        let bytes = Bytes::from(vec![0; 1073]);
        let error = SignedPacket::from_relay_response(keypair.public_key().clone(), bytes);
        assert!(error.is_err());
    }

    #[test]
    fn from_too_large_packet() {
        let keypair = Keypair::random();
        let mut packet = Packet::new_reply(0);
        for _ in 0..100 {
            packet.answers.push(ResourceRecord::new(
                Name::new("_derp_region.iroh.").unwrap(),
                simple_dns::CLASS::IN,
                30,
                RData::A(A {
                    address: Ipv4Addr::new(1, 1, 1, 1).into(),
                }),
            ));
        }
        let error = SignedPacket::from_packet(&keypair, &packet);
        assert!(error.is_err());
    }

    #[test]
    fn resource_records_iterator() {
        let keypair = Keypair::random();
        let target = ResourceRecord::new(
            Name::new("_derp_region.iroh.").unwrap(),
            simple_dns::CLASS::IN,
            30,
            RData::A(A {
                address: Ipv4Addr::new(1, 1, 1, 1).into(),
            }),
        );
        let mut packet = Packet::new_reply(0);
        packet.answers.push(target.clone());
        packet.answers.push(ResourceRecord::new(
            Name::new("something else").unwrap(),
            simple_dns::CLASS::IN,
            30,
            RData::A(A {
                address: Ipv4Addr::new(1, 1, 1, 1).into(),
            }),
        ));
        let signed_packet = SignedPacket::from_packet(&keypair, &packet).unwrap();
        let iter = signed_packet.resource_records("_derp_region.iroh");
        assert_eq!(iter.count(), 1);
        for record in signed_packet.resource_records("_derp_region.iroh") {
            assert_eq!(record.rdata, target.rdata);
        }
    }

    // `MutableItem` is only imported when the `dht` feature is enabled. The test must be
    // gated the same way, or `cargo test` without `dht` fails to compile.
    #[cfg(feature = "dht")]
    #[test]
    fn to_mutable() {
        let keypair = Keypair::random();
        let mut packet = Packet::new_reply(0);
        packet.answers.push(ResourceRecord::new(
            Name::new("_derp_region.iroh.").unwrap(),
            simple_dns::CLASS::IN,
            30,
            RData::A(A {
                address: Ipv4Addr::new(1, 1, 1, 1).into(),
            }),
        ));
        let signed_packet = SignedPacket::from_packet(&keypair, &packet).unwrap();
        let item: MutableItem = (&signed_packet).into();
        let seq: i64 = *signed_packet.timestamp() as i64;
        let expected = MutableItem::new(
            keypair.secret_key().into(),
            signed_packet
                .packet()
                .build_bytes_vec_compressed()
                .unwrap()
                .into(),
            seq,
            None,
        );
        assert_eq!(item, expected);
    }

    #[test]
    fn compressed_names() {
        let keypair = Keypair::random();
        let name = "foobar";
        let dup = name;
        let mut packet = Packet::new_reply(0);
        packet.answers.push(dns::ResourceRecord::new(
            dns::Name::new("@").unwrap(),
            dns::CLASS::IN,
            30,
            dns::rdata::RData::CNAME(dns::Name::new(name).unwrap().into()),
        ));
        packet.answers.push(dns::ResourceRecord::new(
            dns::Name::new("@").unwrap(),
            dns::CLASS::IN,
            30,
            dns::rdata::RData::CNAME(dns::Name::new(dup).unwrap().into()),
        ));
        let signed = SignedPacket::from_packet(&keypair, &packet).unwrap();
        assert_eq!(
            signed
                .resource_records("@")
                .map(|r| r.rdata.clone())
                .collect::<Vec<_>>(),
            packet
                .answers
                .iter()
                .map(|r| r.rdata.clone())
                .collect::<Vec<_>>()
        )
    }
}