use super::super::iana::OptionCode;
use super::super::message_builder::OptBuilder;
use super::super::wire::{Composer, ParseError};
use super::{ComposeOptData, Opt, OptData, ParseOptData};
use crate::base::Serial;
use crate::utils::base16;
use core::{fmt, hash};
use octseq::array::Array;
use octseq::builder::OctetsBuilder;
use octseq::octets::Octets;
use octseq::parse::Parser;
/// Option data for a DNS cookie.
///
/// A cookie consists of a mandatory [`ClientCookie`] and an optional
/// [`ServerCookie`]. A value can be assembled from these parts via
/// [`new`][Self::new].
///
/// With the `rand` feature enabled, an initial cookie with a random client
/// cookie and no server cookie can be created via `create_initial`.
///
/// A server can check whether a received cookie carries a server cookie
/// that it calculated itself via the
#[cfg_attr(
    feature = "siphasher",
    doc = "[`check_server_hash`](Self::check_server_hash)"
)]
#[cfg_attr(not(feature = "siphasher"), doc = "`check_server_hash`")]
/// method. It can build the cookie for its response from a received
/// cookie via the
#[cfg_attr(
    feature = "siphasher",
    doc = "[`create_response`](Self::create_response)"
)]
#[cfg_attr(not(feature = "siphasher"), doc = "`create_response`")]
/// method. Both of these methods require the `siphasher` feature.
#[derive(Clone, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
#[cfg_attr(feature = "rand", derive(Default))]
#[cfg_attr(feature = "serde", derive(serde::Serialize))]
pub struct Cookie {
    /// The client cookie. It is always present.
    client: ClientCookie,

    /// The server cookie, if the option carried one.
    server: Option<ServerCookie>,
}
impl Cookie {
    /// The option code of the cookie option.
    pub(super) const CODE: OptionCode = OptionCode::COOKIE;

    /// Assembles a cookie from a client cookie and an optional server cookie.
    #[must_use]
    pub fn new(client: ClientCookie, server: Option<ServerCookie>) -> Self {
        Self { client, server }
    }

    /// Returns the client cookie.
    #[must_use]
    pub fn client(&self) -> ClientCookie {
        self.client
    }

    /// Returns a reference to the server cookie, if there is one.
    #[must_use]
    pub fn server(&self) -> Option<&ServerCookie> {
        self.server.as_ref()
    }

    /// Parses cookie option data from the remaining content of `parser`.
    ///
    /// The first eight octets are taken as the client cookie. Whatever is
    /// left afterwards becomes the server cookie. If nothing is left, the
    /// cookie has no server part.
    ///
    /// # Errors
    ///
    /// Fails if the client cookie cannot be read or if the remaining data
    /// is not an acceptable server cookie.
    pub fn parse<Octs: AsRef<[u8]> + ?Sized>(
        parser: &mut Parser<'_, Octs>,
    ) -> Result<Self, ParseError> {
        let client = ClientCookie::parse(parser)?;
        let server = ServerCookie::parse_opt(parser)?;
        Ok(Self::new(client, server))
    }

    /// Checks whether the server cookie was calculated by this server.
    ///
    /// Returns `true` only if all of the following hold:
    ///
    /// * a server cookie is present,
    /// * it is in the sixteen-octet standard layout,
    /// * `timestamp_ok` accepts its timestamp, and
    /// * its hash matches the hash recalculated from the client cookie,
    ///   `client_ip`, and `secret`.
    ///
    /// The hash is not recalculated when `timestamp_ok` rejects the
    /// timestamp.
    #[cfg(feature = "siphasher")]
    pub fn check_server_hash(
        &self,
        client_ip: crate::base::net::IpAddr,
        secret: &[u8; 16],
        timestamp_ok: impl FnOnce(Serial) -> bool,
    ) -> bool {
        let standard = match self
            .server
            .as_ref()
            .and_then(ServerCookie::try_to_standard)
        {
            Some(standard) => standard,
            None => return false,
        };
        timestamp_ok(standard.timestamp())
            && standard.check_hash(self.client, client_ip, secret)
    }

    /// Creates a cookie with a random client cookie and no server cookie.
    #[cfg(feature = "rand")]
    #[must_use]
    pub fn create_initial() -> Self {
        Self {
            client: ClientCookie::new_random(),
            server: None,
        }
    }

    /// Creates the cookie to include in a response.
    ///
    /// The client cookie is kept. A new standard server cookie is calculated
    /// for `timestamp`, `client_ip`, and `secret` and attached. Any server
    /// cookie already present in `self` is ignored.
    #[cfg(feature = "siphasher")]
    pub fn create_response(
        &self,
        timestamp: Serial,
        client_ip: crate::base::net::IpAddr,
        secret: &[u8; 16],
    ) -> Self {
        let standard = StandardServerCookie::calculate(
            self.client,
            timestamp,
            client_ip,
            secret,
        );
        Self::new(self.client, Some(ServerCookie::from(standard)))
    }

    /// Octets-conversion hook for the surrounding option machinery.
    ///
    /// A cookie owns no generic octets sequence, so the conversion always
    /// succeeds and returns `src` unchanged.
    pub(super) fn try_octets_from<E>(src: Self) -> Result<Self, E> {
        Ok(src)
    }
}
/// Identifies cookie data as belonging to the cookie option code.
impl OptData for Cookie {
    fn code(&self) -> OptionCode {
        Self::CODE
    }
}
/// Parses cookie data when, and only when, the option code is the cookie
/// code. Other codes yield `Ok(None)` without touching the parser.
impl<'a, Octs: AsRef<[u8]> + ?Sized> ParseOptData<'a, Octs> for Cookie {
    fn parse_option(
        code: OptionCode,
        parser: &mut Parser<'a, Octs>,
    ) -> Result<Option<Self>, ParseError> {
        if code != OptionCode::COOKIE {
            return Ok(None);
        }
        Self::parse(parser).map(Some)
    }
}
impl ComposeOptData for Cookie {
fn compose_len(&self) -> u16 {
match self.server.as_ref() {
Some(server) => ClientCookie::COMPOSE_LEN
.checked_add(server.compose_len())
.expect("long server cookie"),
None => ClientCookie::COMPOSE_LEN,
}
}
fn compose_option<Target: OctetsBuilder + ?Sized>(
&self,
target: &mut Target,
) -> Result<(), Target::AppendError> {
self.client.compose(target)?;
if let Some(server) = self.server.as_ref() {
server.compose(target)?;
}
Ok(())
}
}
/// Formats the client cookie followed immediately, without a separator,
/// by the server cookie if present.
impl fmt::Display for Cookie {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        fmt::Display::fmt(&self.client, f)?;
        match &self.server {
            Some(server) => fmt::Display::fmt(server, f),
            None => Ok(()),
        }
    }
}
impl<Octs: Octets> Opt<Octs> {
    /// Looks up the cookie option of this record via `first`.
    ///
    /// Returns `None` if no cookie option is found.
    pub fn cookie(&self) -> Option<Cookie> {
        self.first::<Cookie>()
    }
}
impl<Target: Composer> OptBuilder<'_, Target> {
    /// Appends a cookie option carrying `cookie`.
    pub fn cookie(
        &mut self,
        cookie: Cookie,
    ) -> Result<(), Target::AppendError> {
        self.push(&cookie)
    }

    /// Appends a cookie option with a freshly created random client cookie
    /// and no server cookie.
    #[cfg(feature = "rand")]
    pub fn initial_cookie(&mut self) -> Result<(), Target::AppendError> {
        self.cookie(Cookie::create_initial())
    }
}
/// The client part of a DNS cookie.
///
/// A client cookie consists of exactly eight octets. It can be created from
/// given octets via [`from_octets`][ClientCookie::from_octets] or filled
/// with random octets via
#[cfg_attr(
    feature = "rand",
    doc = "[`new_random`][ClientCookie::new_random]"
)]
#[cfg_attr(not(feature = "rand"), doc = "`new_random`")]
/// which requires the `rand` feature.
#[derive(Clone, Copy, Debug, Eq, PartialEq, Ord, PartialOrd)]
pub struct ClientCookie([u8; 8]);
/// Serializes the client cookie as its raw octets.
#[cfg(feature = "serde")]
impl serde::Serialize for ClientCookie {
    fn serialize<S: serde::Serializer>(
        &self,
        serializer: S,
    ) -> Result<S::Ok, S::Error> {
        use octseq::serde::SerializeOctets;

        self.0.serialize_octets(serializer)
    }
}
impl ClientCookie {
    /// Length of a client cookie in wire format.
    pub const COMPOSE_LEN: u16 = 8;

    /// Creates a client cookie from its eight octets.
    #[must_use]
    pub const fn from_octets(octets: [u8; 8]) -> Self {
        Self(octets)
    }

    /// Creates a client cookie from eight random octets.
    #[cfg(feature = "rand")]
    #[must_use]
    pub fn new_random() -> Self {
        Self::from_octets(rand::random())
    }

    /// Converts the cookie into its eight octets.
    #[must_use]
    pub fn into_octets(self) -> [u8; 8] {
        let Self(octets) = self;
        octets
    }

    /// Parses a client cookie from the next eight octets of `parser`.
    ///
    /// # Errors
    ///
    /// Fails if the parser cannot provide eight octets.
    pub fn parse<Octs: AsRef<[u8]> + ?Sized>(
        parser: &mut Parser<'_, Octs>,
    ) -> Result<Self, ParseError> {
        let mut octets = [0u8; 8];
        parser.parse_buf(&mut octets[..])?;
        Ok(Self::from_octets(octets))
    }

    /// Appends the eight octets of the cookie to `target`.
    pub fn compose<Target: OctetsBuilder + ?Sized>(
        &self,
        target: &mut Target,
    ) -> Result<(), Target::AppendError> {
        target.append_slice(&self.0[..])
    }
}
/// The default client cookie is a random one.
#[cfg(feature = "rand")]
impl Default for ClientCookie {
    fn default() -> Self {
        ClientCookie::new_random()
    }
}
impl From<[u8; 8]> for ClientCookie {
fn from(src: [u8; 8]) -> Self {
Self::from_octets(src)
}
}
/// Unwraps a client cookie into its eight octets.
impl From<ClientCookie> for [u8; 8] {
    fn from(src: ClientCookie) -> Self {
        src.into_octets()
    }
}
/// Provides read access to the eight cookie octets.
impl AsRef<[u8]> for ClientCookie {
    fn as_ref(&self) -> &[u8] {
        &self.0[..]
    }
}
/// Provides write access to the eight cookie octets.
impl AsMut<[u8]> for ClientCookie {
    fn as_mut(&mut self) -> &mut [u8] {
        &mut self.0[..]
    }
}
/// Feeds exactly the eight raw cookie octets to the hasher.
///
/// This is implemented by hand rather than derived. The standard `Hash`
/// impl for arrays and slices also writes a length prefix, and
/// `StandardServerCookie::calculate_hash` hashes the client cookie through
/// this impl, so it depends on only the raw octets being written.
impl hash::Hash for ClientCookie {
    fn hash<H: hash::Hasher>(&self, state: &mut H) {
        state.write(&self.0[..]);
    }
}
/// Formats the client cookie in base 16.
impl fmt::Display for ClientCookie {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        base16::display(&self.0[..], f)
    }
}
/// The server part of a DNS cookie.
///
/// A server cookie holds between 8 and 32 octets: its constructors reject
/// shorter input, and the backing array cannot hold more than 32 octets.
/// The content is opaque in general. A cookie of exactly sixteen octets can
/// be viewed as a [`StandardServerCookie`] via
/// [`try_to_standard`][ServerCookie::try_to_standard].
#[derive(Clone, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub struct ServerCookie(Array<32>);
/// Serializes the server cookie as its raw octets.
#[cfg(feature = "serde")]
impl serde::Serialize for ServerCookie {
    fn serialize<S: serde::Serializer>(
        &self,
        serializer: S,
    ) -> Result<S::Ok, S::Error> {
        use octseq::serde::SerializeOctets;

        self.0.serialize_octets(serializer)
    }
}
impl ServerCookie {
    /// Creates a server cookie from a slice of octets.
    ///
    /// # Panics
    ///
    /// Panics if `slice` is shorter than 8 octets or longer than 32 octets.
    #[must_use]
    pub fn from_octets(slice: &[u8]) -> Self {
        assert!(slice.len() >= 8, "server cookie shorter than 8 octets");
        let mut res = Array::new();
        // The backing array holds at most 32 octets, so appending fails
        // for anything longer.
        res.append_slice(slice)
            .expect("server cookie longer than 32 octets");
        Self(res)
    }

    /// Parses a server cookie from all the remaining data of `parser`.
    ///
    /// # Errors
    ///
    /// Fails if fewer than 8 or more than 32 octets remain.
    pub fn parse<Octs: AsRef<[u8]> + ?Sized>(
        parser: &mut Parser<'_, Octs>,
    ) -> Result<Self, ParseError> {
        if parser.remaining() < 8 {
            return Err(ParseError::form_error("short server cookie"));
        }
        let mut res = Array::new();
        res.resize_raw(parser.remaining())
            .map_err(|_| ParseError::form_error("long server cookie"))?;
        parser.parse_buf(res.as_slice_mut())?;
        Ok(Self(res))
    }

    /// Parses a server cookie if any data remains in `parser`.
    ///
    /// Returns `Ok(None)` if the parser is already exhausted.
    ///
    /// # Errors
    ///
    /// Fails under the same conditions as [`parse`][Self::parse] when data
    /// remains.
    pub fn parse_opt<Octs: AsRef<[u8]> + ?Sized>(
        parser: &mut Parser<'_, Octs>,
    ) -> Result<Option<Self>, ParseError> {
        if parser.remaining() > 0 {
            Self::parse(parser).map(Some)
        } else {
            Ok(None)
        }
    }

    /// Tries to interpret the cookie as a standard server cookie.
    ///
    /// Succeeds only if the cookie is exactly sixteen octets long.
    pub fn try_to_standard(&self) -> Option<StandardServerCookie> {
        TryFrom::try_from(self.0.as_slice())
            .map(StandardServerCookie)
            .ok()
    }

    /// Returns the length of the cookie in wire format.
    #[must_use]
    pub fn compose_len(&self) -> u16 {
        // At most 32 octets are ever stored, so this cannot actually fail.
        u16::try_from(self.0.len()).expect("long server cookie")
    }

    /// Appends the cookie octets to `target`.
    pub fn compose<Target: OctetsBuilder + ?Sized>(
        &self,
        target: &mut Target,
    ) -> Result<(), Target::AppendError> {
        target.append_slice(self.0.as_ref())
    }
}
impl From<StandardServerCookie> for ServerCookie {
fn from(src: StandardServerCookie) -> Self {
Self::from_octets(&src.0)
}
}
/// Provides read access to the server cookie octets.
impl AsRef<[u8]> for ServerCookie {
    fn as_ref(&self) -> &[u8] {
        self.0.as_slice()
    }
}
/// Formats the server cookie in base 16.
impl fmt::Display for ServerCookie {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        base16::display(self.0.as_slice(), f)
    }
}
/// A server cookie in the sixteen-octet standard layout.
///
/// The octets are, in order:
///
/// * a version (1 octet),
/// * reserved octets (3 octets),
/// * a timestamp in big-endian byte order (4 octets), and
/// * a hash (8 octets).
///
/// See [`new`][Self::new] and the accessor methods for the individual
/// fields.
#[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub struct StandardServerCookie(
    // The complete sixteen octets of the cookie in wire order.
    [u8; 16],
);
impl StandardServerCookie {
    /// Assembles a standard server cookie from its four fields.
    ///
    /// The octets are laid out as `version`, the three `reserved` octets,
    /// the `timestamp` in big-endian byte order, and finally the eight
    /// `hash` octets.
    #[must_use]
    pub fn new(
        version: u8,
        reserved: [u8; 3],
        timestamp: Serial,
        hash: [u8; 8],
    ) -> Self {
        let mut octets = [0u8; 16];
        octets[0] = version;
        octets[1..4].copy_from_slice(&reserved);
        octets[4..8].copy_from_slice(&timestamp.into_int().to_be_bytes());
        octets[8..].copy_from_slice(&hash);
        Self(octets)
    }

    /// Calculates a new version 1 cookie with zero reserved octets for the
    /// given client cookie, timestamp, client address, and secret.
    #[cfg(feature = "siphasher")]
    pub fn calculate(
        client_cookie: ClientCookie,
        timestamp: Serial,
        client_ip: crate::base::net::IpAddr,
        secret: &[u8; 16],
    ) -> Self {
        // Start with a zeroed hash; it is not part of the hash input.
        let mut cookie = Self::new(1, [0; 3], timestamp, [0; 8]);
        let hash = cookie.calculate_hash(client_cookie, client_ip, secret);
        cookie.set_hash(hash);
        cookie
    }

    /// Returns the version octet.
    #[must_use]
    pub fn version(self) -> u8 {
        let [version, ..] = self.0;
        version
    }

    /// Returns the three reserved octets.
    #[must_use]
    pub fn reserved(self) -> [u8; 3] {
        let [_, first, second, third, ..] = self.0;
        [first, second, third]
    }

    /// Returns the timestamp, decoded from big-endian byte order.
    #[must_use]
    pub fn timestamp(self) -> Serial {
        let [_, _, _, _, t0, t1, t2, t3, ..] = self.0;
        Serial::from_be_bytes([t0, t1, t2, t3])
    }

    /// Returns the eight hash octets.
    #[must_use]
    pub fn hash(self) -> [u8; 8] {
        let mut hash = [0u8; 8];
        hash.copy_from_slice(&self.0[8..]);
        hash
    }

    /// Replaces the eight hash octets.
    pub fn set_hash(&mut self, hash: [u8; 8]) {
        self.0[8..16].copy_from_slice(&hash);
    }

    /// Returns whether the stored hash matches the hash recalculated from
    /// the client cookie, the client address, and the secret.
    #[cfg(feature = "siphasher")]
    pub fn check_hash(
        self,
        client_cookie: ClientCookie,
        client_ip: crate::base::net::IpAddr,
        secret: &[u8; 16],
    ) -> bool {
        self.hash() == self.calculate_hash(client_cookie, client_ip, secret)
    }

    /// Calculates the SipHash-2-4 hash over the client cookie, the first
    /// eight octets of this cookie (version, reserved, timestamp), and the
    /// client address, keyed with `secret`. The result is the hasher's
    /// output in little-endian byte order.
    #[cfg(feature = "siphasher")]
    fn calculate_hash(
        self,
        client_cookie: ClientCookie,
        client_ip: crate::base::net::IpAddr,
        secret: &[u8; 16],
    ) -> [u8; 8] {
        use crate::base::net::IpAddr;
        use core::hash::{Hash, Hasher};

        let mut sip = siphasher::sip::SipHasher24::new_with_key(secret);

        // The order of the inputs is significant: client cookie first ...
        client_cookie.hash(&mut sip);
        // ... then everything of this cookie except the hash ...
        sip.write(&self.0[..8]);
        // ... and finally the raw address octets.
        match client_ip {
            IpAddr::V4(addr) => sip.write(&addr.octets()),
            IpAddr::V6(addr) => sip.write(&addr.octets()),
        }

        sip.finish().to_le_bytes()
    }
}
/// Formats all sixteen cookie octets in base 16.
impl fmt::Display for StandardServerCookie {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        base16::display(&self.0[..], f)
    }
}
#[cfg(test)]
mod test {
    // Only the feature-gated module below uses these items, so the import
    // is unused when those features are disabled.
    #[allow(unused_imports)]
    use super::*;

    /// Round-trip tests for standard server cookies against fixed vectors.
    ///
    /// NOTE(review): the addresses, secrets, and expected octets look like
    /// the examples from RFC 9018, Appendix A — TODO confirm.
    #[cfg(all(feature = "siphasher", feature = "std"))]
    mod standard_server {
        use super::*;
        use crate::base::net::{IpAddr, Ipv4Addr, Ipv6Addr};
        use crate::base::wire::{compose_vec, parse_slice};

        // Client addresses used by the individual scenarios.
        const CLIENT_1: IpAddr = IpAddr::V4(Ipv4Addr::new(198, 51, 100, 100));
        const CLIENT_2: IpAddr = IpAddr::V4(Ipv4Addr::new(203, 0, 113, 203));
        const CLIENT_6: IpAddr = IpAddr::V6(Ipv6Addr::new(
            0x2001, 0xdb8, 0x220, 0x1, 0x59de, 0xd0f4, 0x8769, 0x82b8,
        ));

        // Server secret shared by all scenarios except `new_secret`.
        const SECRET: [u8; 16] = [
            0xe5, 0xe9, 0x73, 0xe5, 0xa6, 0xb2, 0xa4, 0x3f, 0x48, 0xe7, 0xdc,
            0x84, 0x9e, 0x37, 0xbf, 0xcf,
        ];

        /// A request with only a client cookie gets a response carrying a
        /// freshly calculated server cookie.
        #[test]
        fn new_cookie() {
            let request = Cookie::new(
                ClientCookie::from_octets([
                    0x24, 0x64, 0xc4, 0xab, 0xcf, 0x10, 0xc9, 0x57,
                ]),
                None,
            );
            // Without a server cookie, only the eight client octets go out.
            assert_eq!(
                compose_vec(|vec| request.compose_option(vec)),
                base16::decode_vec("2464c4abcf10c957").unwrap()
            );
            // Response: client cookie, version 1, zero reserved octets,
            // timestamp 0x5cf79f11, then the eight hash octets.
            assert_eq!(
                compose_vec(|vec| {
                    request
                        .create_response(
                            Serial(1559731985),
                            CLIENT_1,
                            &SECRET,
                        )
                        .compose_option(vec)
                }),
                base16::decode_vec(
                    "2464c4abcf10c957010000005cf79f111f8130c3eee29480"
                )
                .unwrap()
            );
        }

        /// A returning client's server cookie verifies, and the response
        /// carries a cookie recalculated for a later timestamp.
        #[test]
        fn renew_cookie() {
            let request = parse_slice(
                &base16::decode_vec(
                    "2464c4abcf10c957010000005cf79f111f8130c3eee29480",
                )
                .unwrap(),
                Cookie::parse,
            )
            .unwrap();
            assert!(request
                .check_server_hash(CLIENT_1, &SECRET, |serial| serial
                    == Serial(1559731985)));
            assert_eq!(
                compose_vec(|vec| {
                    request
                        .create_response(
                            Serial(1559734385),
                            CLIENT_1,
                            &SECRET,
                        )
                        .compose_option(vec)
                }),
                base16::decode_vec(
                    "2464c4abcf10c957010000005cf7a871d4a564a1442aca77"
                )
                .unwrap()
            );
        }

        /// The incoming server cookie has reserved octets `abcdef`. It must
        /// still verify, since the reserved octets are part of the hash
        /// input. The response resets them to zero.
        #[test]
        fn non_zero_reserved() {
            let request = parse_slice(
                &base16::decode_vec(
                    "fc93fc62807ddb8601abcdef5cf78f71a314227b6679ebf5",
                )
                .unwrap(),
                Cookie::parse,
            )
            .unwrap();
            assert!(request
                .check_server_hash(CLIENT_2, &SECRET, |serial| serial
                    == Serial(1559727985)));
            assert_eq!(
                compose_vec(|vec| {
                    request
                        .create_response(
                            Serial(1559734700),
                            CLIENT_2,
                            &SECRET,
                        )
                        .compose_option(vec)
                }),
                base16::decode_vec(
                    "fc93fc62807ddb86010000005cf7a9acf73a7810aca2381e"
                )
                .unwrap()
            );
        }

        /// After a secret rollover, the cookie verifies only against the
        /// old secret. The response is calculated with the new one, here
        /// for an IPv6 client.
        #[test]
        fn new_secret() {
            const OLD_SECRET: [u8; 16] = [
                0xdd, 0x3b, 0xdf, 0x93, 0x44, 0xb6, 0x78, 0xb1, 0x85, 0xa6,
                0xf5, 0xcb, 0x60, 0xfc, 0xa7, 0x15,
            ];
            const NEW_SECRET: [u8; 16] = [
                0x44, 0x55, 0x36, 0xbc, 0xd2, 0x51, 0x32, 0x98, 0x07, 0x5a,
                0x5d, 0x37, 0x96, 0x63, 0xc9, 0x62,
            ];
            let request = parse_slice(
                &base16::decode_vec(
                    "22681ab97d52c298010000005cf7c57926556bd0934c72f8",
                )
                .unwrap(),
                Cookie::parse,
            )
            .unwrap();
            assert!(!request.check_server_hash(
                CLIENT_6,
                &NEW_SECRET,
                |serial| serial == Serial(1559741817)
            ));
            assert!(request.check_server_hash(
                CLIENT_6,
                &OLD_SECRET,
                |serial| serial == Serial(1559741817)
            ));
            assert_eq!(
                compose_vec(|vec| {
                    request
                        .create_response(
                            Serial(1559741961),
                            CLIENT_6,
                            &NEW_SECRET,
                        )
                        .compose_option(vec)
                }),
                base16::decode_vec(
                    "22681ab97d52c298010000005cf7c609a6bb79d16625507a"
                )
                .unwrap()
            );
        }
    }
}