use std::fmt;
use std::net::IpAddr;
use std::net::SocketAddr;
use std::time;
use std::time::Duration;
use ring::aead;
use ring::hmac;
use self::AddressTokenType::*;
use crate::codec::Decoder;
use crate::codec::Encoder;
use crate::error::Error;
use crate::ConnectionId;
use crate::Result;
use crate::RESET_TOKEN_LEN;
/// Kind of address-validation token.
///
/// The discriminant is the byte written right after the `quic` label when a
/// token is encoded, and it is read back by `AddressToken::token_type`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum AddressTokenType {
    /// Token built by `AddressToken::new_retry_token`. Its associated data binds
    /// the peer IP, the port and the retry source CID, and its payload carries
    /// the original destination CID.
    RetryToken = 0,

    /// Token built by `AddressToken::new_resume_token`. Its associated data binds
    /// only the peer IP (not the port), and it carries no connection IDs.
    ResumeToken = 1,
}
#[derive(Debug)]
pub struct AddressToken {
pub token_type: AddressTokenType,
pub issued: time::SystemTime,
pub address: SocketAddr,
pub odcid: Option<ConnectionId>,
pub rscid: Option<ConnectionId>,
}
impl AddressToken {
/// Creates a retry token for `address`, stamped with the current system time.
///
/// When encoded, `odcid` is written into the encrypted payload and `rscid`
/// is bound into the AEAD associated data.
pub fn new_retry_token(
address: SocketAddr,
odcid: ConnectionId,
rscid: ConnectionId,
) -> AddressToken {
AddressToken {
token_type: RetryToken,
issued: time::SystemTime::now(),
address,
odcid: Some(odcid),
rscid: Some(rscid),
}
}
/// Creates a resume token for `address`, stamped with the current system time.
/// Resume tokens carry no connection IDs.
pub fn new_resume_token(address: SocketAddr) -> AddressToken {
AddressToken {
token_type: ResumeToken,
issued: time::SystemTime::now(),
address,
odcid: None,
rscid: None,
}
}
/// Serializes the token and seals its payload with `key`.
///
/// Layout: `b"quic"` (4 bytes) | token type (1 byte) | random nonce
/// (`aead::NONCE_LEN` bytes), all in cleartext. These are followed by the sealed
/// payload and its AEAD tag.
///
/// The payload is the issue time in whole seconds since the Unix epoch
/// (written with `write_u64`). Retry tokens then add a one-byte ODCID length
/// followed by the ODCID bytes.
///
/// The associated data comes from `additional_data`: the peer IP, plus the
/// port and `rscid` for retry tokens.
///
/// # Errors
///
/// Returns `Error::InternalError` if a retry token lacks `odcid` or `rscid`,
/// if `issued` is before the Unix epoch, or if sealing fails. Errors from the
/// encoder writes are propagated.
pub fn encode(&self, key: &aead::LessSafeKey) -> Result<Vec<u8>> {
// Allocate the worst-case size; the unused tail is truncated below.
let max_len = AddressToken::max_token_len(key);
let mut token = vec![0u8; max_len];
let nonce = rand::random::<[u8; aead::NONCE_LEN]>();
// `buf` is a write cursor over `token`. The encoder presumably advances it
// past each write, so `max_len - buf.len()` is the number of bytes written.
let mut buf = token.as_mut_slice();
buf.write(b"quic")?;
buf.write_u8(self.token_type as u8)?;
buf.write(&nonce)?;
// Length of the cleartext header (label + type + nonce).
let hdr_len = max_len - buf.len();
let seconds = self.issue_time()?;
buf.write_u64(seconds)?;
if self.token_type == RetryToken {
if let Some(odcid) = self.odcid {
buf.write_u8(odcid.len() as u8)?;
buf.write(&odcid)?;
} else {
return Err(Error::InternalError);
}
}
// Header plus plaintext payload; the tag has not been added yet.
let token_len = max_len - buf.len();
let nonce = aead::Nonce::assume_unique_for_key(nonce);
let aad =
AddressToken::additional_data(self.token_type, &self.address, self.rscid.as_ref())?;
let aad = aead::Aad::from(&aad);
// Drop the unused tail and detach the payload. Seal it in place, which
// appends the tag, then reattach it after the cleartext header.
token.truncate(token_len);
let mut buf = token.split_off(hdr_len);
key.seal_in_place_append_tag(nonce, aad, &mut buf)
.map_err(|_| Error::InternalError)?;
token.append(&mut buf);
Ok(token)
}
/// Parses, authenticates and validates a token produced by
/// [`AddressToken::encode`]. The payload is decrypted in place inside `token`.
///
/// `address` must match the address the token was issued to: its IP, and for
/// retry tokens also its port. For retry tokens, `pkt_dcid` must also equal
/// the `rscid` the token was sealed with. Otherwise authentication fails.
///
/// On success:
/// - a retry token gets `rscid = Some(*pkt_dcid)` and the ODCID from the
///   payload (`None` if it was zero-length);
/// - a resume token gets `odcid = Some(*pkt_dcid)` and `rscid = None`.
///
/// # Errors
///
/// Returns `Error::InvalidToken` for any of:
/// - a wrong label or unknown type byte;
/// - an authentication failure;
/// - an unrepresentable issue time;
/// - an issue time later than now or older than `lifetime`;
/// - an ODCID length above 20.
///
/// Errors from the decoder reads (e.g. on a truncated token) are propagated.
pub fn decode(
key: &aead::LessSafeKey,
token: &mut [u8],
address: &SocketAddr,
pkt_dcid: &ConnectionId,
lifetime: Duration,
) -> Result<AddressToken> {
let mut buf: &[u8] = token;
let label = buf.read(4)?;
if label != b"quic" {
return Err(Error::InvalidToken);
}
let token_type = buf.read_u8()?;
let token_type = match token_type {
0 => RetryToken,
1 => ResumeToken,
_ => return Err(Error::InvalidToken),
};
let nonce = buf.read(aead::NONCE_LEN)?;
// Cleartext header length; the rest of `token` is ciphertext plus tag.
let hdr_len = token.len() - buf.len();
// Retry tokens were sealed with their rscid in the associated data, so the
// packet's DCID must equal that rscid for authentication to succeed.
let rscid = if token_type == RetryToken {
Some(pkt_dcid)
} else {
None
};
// Exactly NONCE_LEN bytes were read above, so the length check cannot fail.
let nonce =
aead::Nonce::try_assume_unique_for_key(&nonce).map_err(|_| Error::InternalError)?;
let aad = AddressToken::additional_data(token_type, address, rscid)?;
let aad = aead::Aad::from(&aad);
// Decrypt in place; on success the plaintext sits at the start of
// `token[hdr_len..]`, followed by the tag bytes.
let buf = &mut token[hdr_len..];
key.open_in_place(nonce, aad, buf)
.map_err(|_| Error::InvalidToken)?;
// Re-borrow immutably to parse the decrypted payload.
let mut buf = &token[hdr_len..];
let issued = buf.read_u64()?;
let issued = match time::UNIX_EPOCH.checked_add(Duration::from_secs(issued)) {
Some(v) => v,
None => return Err(Error::InvalidToken),
};
// `elapsed()` fails when `issued` is later than now; that is rejected too.
if let Ok(duration) = issued.elapsed() {
if duration > lifetime {
return Err(Error::InvalidToken);
}
} else {
return Err(Error::InvalidToken);
}
// Retry tokens carry a length-prefixed ODCID (length 0 decodes as None).
// Resume tokens carry none, so the packet's DCID is reported instead.
let odcid = if token_type == RetryToken {
let cid_len = buf.read_u8()?;
match cid_len {
0 => None,
1..=20 => {
let cid = buf.read(cid_len as usize)?;
Some(ConnectionId::new(&cid))
}
_ => return Err(Error::InvalidToken),
}
} else {
Some(*pkt_dcid)
};
Ok(AddressToken {
token_type,
issued,
address: *address,
odcid,
rscid: rscid.copied(),
})
}
/// Returns `issued` as whole seconds since the Unix epoch. The sub-second
/// part is dropped.
///
/// # Errors
///
/// Returns `Error::InternalError` if `issued` is before the Unix epoch.
fn issue_time(&self) -> Result<u64> {
self.issued
.duration_since(time::UNIX_EPOCH)
.map(|x| x.as_secs())
.map_err(|_| Error::InternalError)
}
/// Builds the AEAD associated data that binds a token to its context.
///
/// The data is the peer's IP octets (4 or 16 bytes). Retry tokens also
/// include the port and the `rscid` bytes.
///
/// # Errors
///
/// Returns `Error::InternalError` for a retry token without `rscid`. Errors
/// from the encoder writes are propagated.
fn additional_data(
token_type: AddressTokenType,
address: &SocketAddr,
rscid: Option<&ConnectionId>,
) -> Result<Vec<u8>> {
// Capacity: 16-byte IPv6 address + 2-byte port + CID of up to 20 bytes.
const MAX_LEN: usize = 16 + 2 + 20;
let mut data = vec![0u8; MAX_LEN];
let mut buf = data.as_mut_slice();
let addr = match address.ip() {
IpAddr::V4(a) => a.octets().to_vec(),
IpAddr::V6(a) => a.octets().to_vec(),
};
buf.write(&addr)?;
if token_type == RetryToken {
buf.write_u16(address.port())?;
if let Some(rscid) = rscid {
buf.write(rscid)?;
} else {
return Err(Error::InternalError);
}
}
// Keep only the bytes actually written.
let len = MAX_LEN - buf.len();
data.truncate(len);
Ok(data)
}
/// Returns the upper bound on an encoded token's size. The bound is:
/// label (4) + type (1) + nonce + issue time (8) + ODCID length byte and up
/// to 20 ODCID bytes (21) + AEAD tag.
fn max_token_len(key: &aead::LessSafeKey) -> usize {
4 + 1 + aead::NONCE_LEN + 8 + 21 + key.algorithm().tag_len()
}
/// Reads the type byte (offset 4) of an encoded token without authenticating
/// it. The `quic` label is not checked here; `decode` checks it.
///
/// # Errors
///
/// Returns `Error::InvalidToken` if `token` is shorter than 5 bytes or the
/// type byte is unknown.
pub fn token_type(token: &[u8]) -> Result<AddressTokenType> {
if token.len() < 5 {
return Err(Error::InvalidToken);
}
match token[4] {
0 => Ok(RetryToken),
1 => Ok(ResumeToken),
_ => Err(Error::InvalidToken),
}
}
}
/// A stateless reset token, held as a fixed-size array of raw bytes.
/// `Default` yields an all-zero token.
#[derive(Clone, Copy, Default, Eq, Hash, PartialEq)]
pub struct ResetToken(pub [u8; RESET_TOKEN_LEN]);
impl ResetToken {
    /// Builds a reset token from the first `RESET_TOKEN_LEN` bytes of `data`.
    ///
    /// Bytes beyond `RESET_TOKEN_LEN` are ignored.
    ///
    /// # Errors
    ///
    /// Returns `Error::BufferTooShort` if `data` holds fewer than
    /// `RESET_TOKEN_LEN` bytes.
    pub(crate) fn new(data: &[u8]) -> Result<Self> {
        if data.len() < RESET_TOKEN_LEN {
            return Err(Error::BufferTooShort);
        }
        let mut token = ResetToken::default();
        // Copy exactly RESET_TOKEN_LEN bytes. Slice copies panic when the
        // lengths differ, so an oversized `data` must not be passed through whole.
        token.0.copy_from_slice(&data[..RESET_TOKEN_LEN]);
        Ok(token)
    }

    /// Derives the reset token for `id`: the first `RESET_TOKEN_LEN` bytes of
    /// HMAC(`key`, `id`).
    ///
    /// The same key and connection ID always produce the same token.
    pub(crate) fn generate(key: &hmac::Key, id: &ConnectionId) -> Self {
        let tag = hmac::sign(key, id);
        let mut token = ResetToken::default();
        token.0.copy_from_slice(&tag.as_ref()[..RESET_TOKEN_LEN]);
        token
    }

    /// Extracts the token from the trailing `RESET_TOKEN_LEN` bytes of `buf`.
    ///
    /// # Errors
    ///
    /// Returns `Error::BufferTooShort` if `buf` is shorter than
    /// `crate::MIN_RESET_PACKET_LEN`.
    pub(crate) fn from_bytes(buf: &[u8]) -> Result<Self> {
        if buf.len() < crate::MIN_RESET_PACKET_LEN {
            return Err(Error::BufferTooShort);
        }
        let mut token = ResetToken::default();
        token.0.copy_from_slice(&buf[buf.len() - RESET_TOKEN_LEN..]);
        Ok(token)
    }

    /// Interprets the token bytes as a big-endian `u128`.
    pub(crate) fn to_u128(self) -> u128 {
        u128::from_be_bytes(self.0)
    }

    /// Inverse of [`ResetToken::to_u128`]: builds a token from the big-endian
    /// bytes of `v`.
    pub(crate) fn from_u128(v: u128) -> ResetToken {
        ResetToken(v.to_be_bytes())
    }
}
/// Exposes the token's raw bytes, so a `ResetToken` can be used wherever a
/// byte slice is expected.
impl std::ops::Deref for ResetToken {
    type Target = [u8];

    fn deref(&self) -> &Self::Target {
        self.0.as_slice()
    }
}
/// Renders the token as one contiguous lowercase hex string, two digits per
/// byte, with no prefix or separators.
impl fmt::Debug for ResetToken {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        self.0
            .iter()
            .try_for_each(|byte| write!(f, "{byte:02x}"))
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use ring::aead::LessSafeKey;
    use ring::aead::UnboundKey;
    use std::net::IpAddr;
    use std::net::Ipv4Addr;
    use std::net::Ipv6Addr;

    /// Returns `true` when `b` (a decoded token) matches `a` (the original) as
    /// closely as the wire format allows.
    ///
    /// - The issue time is stored in whole seconds, so it only has to agree to
    ///   within one second.
    /// - `decode` reports a zero-length ODCID as `None`, so an empty ODCID is
    ///   treated as absent.
    /// - The ODCID is compared only for retry tokens, because decoding a resume
    ///   token reports the packet's DCID instead.
    fn cmp_address_token(a: &AddressToken, b: &AddressToken) -> bool {
        let duration = if a.issued > b.issued {
            a.issued.duration_since(b.issued).unwrap()
        } else {
            b.issued.duration_since(a.issued).unwrap()
        };
        let non_empty = |cid: Option<ConnectionId>| cid.filter(|c| !c.is_empty());
        duration < Duration::from_secs(1)
            && a.token_type == b.token_type
            && a.address == b.address
            && a.rscid == b.rscid
            && (a.token_type != RetryToken || non_empty(a.odcid) == non_empty(b.odcid))
    }

    #[test]
    fn address_token_normal() -> Result<()> {
        let ip4 = Ipv4Addr::new(192, 168, 1, 1);
        let ip6 = Ipv6Addr::new(0x26, 0, 0x1c9, 0, 0, 0xafc8, 0x10, 0x1);
        let key = LessSafeKey::new(UnboundKey::new(&aead::AES_128_GCM, &[1; 16]).unwrap());
        let cid0 = ConnectionId {
            len: 0,
            data: [0; 20],
        };
        let lifetime = Duration::from_secs(86400);

        let retry_token_tests = [
            AddressToken::new_retry_token(
                SocketAddr::new(IpAddr::V4(ip4), 8888),
                ConnectionId::random(),
                ConnectionId::random(),
            ),
            AddressToken::new_retry_token(
                SocketAddr::new(IpAddr::V6(ip6), 8888),
                ConnectionId::random(),
                ConnectionId::random(),
            ),
            AddressToken::new_retry_token(
                SocketAddr::new(IpAddr::V6(ip6), 8888),
                cid0,
                ConnectionId::random(),
            ),
            AddressToken::new_retry_token(
                SocketAddr::new(IpAddr::V6(ip6), 8888),
                ConnectionId::random(),
                cid0,
            ),
            AddressToken::new_retry_token(SocketAddr::new(IpAddr::V6(ip6), 8888), cid0, cid0),
        ];
        for token in retry_token_tests {
            let mut buf = token.encode(&key)?;
            let decoded = AddressToken::decode(
                &key,
                &mut buf,
                &token.address,
                &token.rscid.unwrap(),
                lifetime,
            )?;
            assert!(
                cmp_address_token(&token, &decoded),
                "retry token mismatch: {token:?} != {decoded:?}"
            );
        }

        let resume_token_tests = [
            AddressToken::new_resume_token(SocketAddr::new(IpAddr::V4(ip4), 0)),
            AddressToken::new_resume_token(SocketAddr::new(IpAddr::V6(ip6), 0)),
        ];
        for token in resume_token_tests {
            let mut buf = token.encode(&key)?;
            let decoded = AddressToken::decode(
                &key,
                &mut buf,
                &token.address,
                &ConnectionId::random(),
                lifetime,
            )?;
            assert!(
                cmp_address_token(&token, &decoded),
                "resume token mismatch: {token:?} != {decoded:?}"
            );
        }
        Ok(())
    }

    #[test]
    fn address_token_invalid() -> Result<()> {
        let key = LessSafeKey::new(UnboundKey::new(&aead::AES_128_GCM, &[1; 16]).unwrap());
        let ip4 = Ipv4Addr::new(192, 168, 1, 1);
        let lifetime = Duration::from_secs(86400);
        for token in [
            AddressToken {
                token_type: RetryToken,
                issued: time::SystemTime::now(),
                address: SocketAddr::new(IpAddr::V4(ip4), 8888),
                odcid: None,
                rscid: Some(ConnectionId::random()),
            },
            AddressToken {
                token_type: RetryToken,
                issued: time::SystemTime::now(),
                address: SocketAddr::new(IpAddr::V4(ip4), 8888),
                odcid: Some(ConnectionId::random()),
                rscid: None,
            },
            AddressToken {
                token_type: RetryToken,
                issued: time::SystemTime::now(),
                address: SocketAddr::new(IpAddr::V4(ip4), 8888),
                odcid: None,
                rscid: None,
            },
        ] {
            assert!(token.encode(&key).is_err());
        }

        // Hand-crafted byte vectors:
        // 1) bad label ("quiu");
        // 2) unknown type byte 0x02;
        // 3) also type 0x02, so it is rejected before the port is ever checked.
        // Address/CID binding of genuine tokens is covered by
        // `address_token_binding`.
        for (mut buf, ip) in [
            (
                [
                    0x71, 0x75, 0x69, 0x75, 0x00, 0xa3, 0x2c, 0xba, 0x33, 0x00, 0x7c, 0x54, 0xdb,
                    0xd3, 0xb3, 0x50, 0x0f, 0xff, 0x80, 0x2a, 0x18, 0x01, 0x4f, 0x67, 0xa1, 0x39,
                    0x06, 0xcf, 0x95, 0xfc, 0x2b, 0x5f, 0xf7, 0xe2, 0x34, 0x81, 0x62, 0x72, 0x79,
                    0xd5, 0x17, 0x18, 0x91, 0x7f, 0x56, 0x01, 0xde, 0xf6, 0x20, 0x61, 0x7c, 0xd1,
                    0x7c, 0x44, 0xec, 0xce, 0xeb, 0x72, 0xe6, 0x63, 0x81, 0xb2,
                ],
                SocketAddr::new(IpAddr::V4(ip4), 8888),
            ),
            (
                [
                    0x71, 0x75, 0x69, 0x63, 0x02, 0xa3, 0x2c, 0xba, 0x33, 0x00, 0x7c, 0x54, 0xdb,
                    0xd3, 0xb3, 0x50, 0x0f, 0xff, 0x80, 0x2a, 0x18, 0x01, 0x4f, 0x67, 0xa1, 0x39,
                    0x06, 0xcf, 0x95, 0xfc, 0x2b, 0x5f, 0xf7, 0xe2, 0x34, 0x81, 0x62, 0x72, 0x79,
                    0xd5, 0x17, 0x18, 0x91, 0x7f, 0x56, 0x01, 0xde, 0xf6, 0x20, 0x61, 0x7c, 0xd1,
                    0x7c, 0x44, 0xec, 0xce, 0xeb, 0x72, 0xe6, 0x63, 0x81, 0xb2,
                ],
                SocketAddr::new(IpAddr::V4(ip4), 8888),
            ),
            (
                [
                    0x71, 0x75, 0x69, 0x63, 0x02, 0xa3, 0x2c, 0xba, 0x33, 0x00, 0x7c, 0x54, 0xdb,
                    0xd3, 0xb3, 0x50, 0x0f, 0xff, 0x80, 0x2a, 0x18, 0x01, 0x4f, 0x67, 0xa1, 0x39,
                    0x06, 0xcf, 0x95, 0xfc, 0x2b, 0x5f, 0xf7, 0xe2, 0x34, 0x81, 0x62, 0x72, 0x79,
                    0xd5, 0x17, 0x18, 0x91, 0x7f, 0x56, 0x01, 0xde, 0xf6, 0x20, 0x61, 0x7c, 0xd1,
                    0x7c, 0x44, 0xec, 0xce, 0xeb, 0x72, 0xe6, 0x63, 0x81, 0xb2,
                ],
                SocketAddr::new(IpAddr::V4(ip4), 8889),
            ),
        ] {
            assert!(
                AddressToken::decode(&key, &mut buf, &ip, &ConnectionId::random(), lifetime)
                    .is_err()
            );
        }
        Ok(())
    }

    /// A genuine retry token must only decode in the exact context it was
    /// issued for, and any tampering must be detected.
    #[test]
    fn address_token_binding() -> Result<()> {
        let key = LessSafeKey::new(UnboundKey::new(&aead::AES_128_GCM, &[1; 16]).unwrap());
        let ip4 = Ipv4Addr::new(192, 168, 1, 1);
        let addr = SocketAddr::new(IpAddr::V4(ip4), 8888);
        let lifetime = Duration::from_secs(86400);
        let rscid = ConnectionId::random();
        let encoded =
            AddressToken::new_retry_token(addr, ConnectionId::random(), rscid).encode(&key)?;

        // Sanity check: the untouched token decodes in its own context.
        let mut buf = encoded.clone();
        assert!(AddressToken::decode(&key, &mut buf, &addr, &rscid, lifetime).is_ok());

        // Bound to the client port...
        let mut buf = encoded.clone();
        let other_port = SocketAddr::new(IpAddr::V4(ip4), 8889);
        assert!(AddressToken::decode(&key, &mut buf, &other_port, &rscid, lifetime).is_err());

        // ...to the client IP...
        let mut buf = encoded.clone();
        let other_ip = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(192, 168, 1, 2)), 8888);
        assert!(AddressToken::decode(&key, &mut buf, &other_ip, &rscid, lifetime).is_err());

        // ...and to the retry source connection ID.
        let mut buf = encoded.clone();
        let other_cid = ConnectionId::random();
        assert!(AddressToken::decode(&key, &mut buf, &addr, &other_cid, lifetime).is_err());

        // Flipping a bit in the sealed part is detected.
        let mut buf = encoded.clone();
        let last = buf.len() - 1;
        buf[last] ^= 0x01;
        assert!(AddressToken::decode(&key, &mut buf, &addr, &rscid, lifetime).is_err());

        // A different key cannot open the token.
        let other_key = LessSafeKey::new(UnboundKey::new(&aead::AES_128_GCM, &[2; 16]).unwrap());
        let mut buf = encoded;
        assert!(AddressToken::decode(&other_key, &mut buf, &addr, &rscid, lifetime).is_err());
        Ok(())
    }

    #[test]
    fn reset_token() -> Result<()> {
        let key = hmac::Key::new(hmac::HMAC_SHA256, &[]);
        let c1 = ConnectionId::random();
        let c2 = ConnectionId::random();
        assert_eq!(
            ResetToken::generate(&key, &c1),
            ResetToken::generate(&key, &c1)
        );
        assert_ne!(
            ResetToken::generate(&key, &c1),
            ResetToken::generate(&key, &c2)
        );
        let buf = [1; crate::RESET_TOKEN_LEN - 1];
        assert_eq!(ResetToken::new(&buf), Err(Error::BufferTooShort));
        assert_eq!(ResetToken::from_bytes(&buf), Err(Error::BufferTooShort));
        let token = ResetToken::generate(&key, &c1);
        assert_eq!(ResetToken::from_u128(token.to_u128()), token);
        assert_eq!(token.to_u128().to_be_bytes(), token.0);
        Ok(())
    }
}