use std::fmt;
use rust_sodium::randombytes::randombytes_into;
use serde::ser::{Serialize, Serializer};
use serde::de::{Deserialize, Deserializer, Visitor, Error as SerdeError};
use crate::helpers::libsodium_init_or_panic;
/// Length of a cookie in bytes.
const COOKIE_BYTES: usize = 16;

/// A fixed-size, `COOKIE_BYTES`-long byte string.
///
/// Equality and hashing are derived from the raw bytes, so two cookies are
/// equal exactly when their bytes match.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub(crate) struct Cookie([u8; COOKIE_BYTES]);
impl Cookie {
    /// Wrap an existing `COOKIE_BYTES`-long byte array as a cookie.
    pub(crate) fn new(bytes: [u8; COOKIE_BYTES]) -> Self {
        Cookie(bytes)
    }

    /// Generate a fresh cookie from libsodium's random number generator.
    ///
    /// # Panics
    ///
    /// Calls `libsodium_init_or_panic` first, which (as its name suggests)
    /// panics if libsodium cannot be initialized. Also panics if the generated
    /// bytes are all zero.
    pub(crate) fn random() -> Self {
        libsodium_init_or_panic();
        // Size the buffer from the shared constant so it always matches the
        // struct's array length.
        let mut rand = [0; COOKIE_BYTES];
        randombytes_into(&mut rand);
        // A genuinely random all-zero result has probability 2^-128, so an
        // all-zero buffer is treated as an RNG failure rather than a valid cookie.
        assert!(
            rand.iter().any(|&x| x != 0),
            "randombytes_into produced an all-zero cookie"
        );
        Cookie(rand)
    }

    /// Borrow the raw cookie bytes.
    pub(crate) fn as_bytes(&self) -> &[u8] {
        &self.0
    }
}
impl Serialize for Cookie {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where S: Serializer {
serializer.serialize_bytes(&self.0)
}
}
struct CookieVisitor;
impl<'de> Visitor<'de> for CookieVisitor {
type Value = Cookie;
fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
formatter.write_str("16 bytes of binary data")
}
fn visit_bytes<E>(self, v: &[u8]) -> Result<Self::Value, E> where E: SerdeError {
if v.len() != 16 {
return Err(SerdeError::invalid_length(v.len(), &self));
}
Ok(Cookie::new([v[0], v[1], v[2], v[3], v[4], v[5], v[6], v[7],
v[8], v[9], v[10], v[11], v[12], v[13], v[14], v[15]]))
}
fn visit_byte_buf<E>(self, v: Vec<u8>) -> Result<Self::Value, E> where E: SerdeError {
self.visit_bytes(&v)
}
}
impl<'de> Deserialize<'de> for Cookie {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where D: Deserializer<'de> {
deserializer.deserialize_bytes(CookieVisitor)
}
}
/// Our own cookie together with the peer's cookie, if one is known yet.
#[derive(Clone, Debug, Eq, PartialEq)]
pub(crate) struct CookiePair {
    /// The cookie generated locally.
    pub(crate) ours: Cookie,
    /// The other side's cookie; `None` until it has been set.
    pub(crate) theirs: Option<Cookie>,
}
impl CookiePair {
    /// Create a pair with a freshly generated random cookie for our side and
    /// no cookie for the other side.
    ///
    /// # Panics
    ///
    /// Panics whenever `Cookie::random` panics.
    pub(crate) fn new() -> Self {
        let ours = Cookie::random();
        Self { ours, theirs: None }
    }
}
#[cfg(test)]
mod tests {
    use std::collections::HashSet;

    use rmp_serde as rmps;

    use super::*;

    /// 100 random cookies must all be distinct.
    #[test]
    fn random_distinct() {
        let mut cookies = HashSet::new();
        for _ in 0..100 {
            let cookie = Cookie::random();
            cookies.insert(cookie);
        }
        assert_eq!(cookies.len(), 100);
    }

    /// Cookies serialize as a msgpack `bin 8` (0xc4) blob of 16 bytes.
    #[test]
    fn cookie_serialize() {
        let cookie = Cookie::new([1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 3, 4, 5, 6, 7, 8]);
        let serialized = rmps::to_vec_named(&cookie).expect("Serialization failed");
        assert_eq!(serialized, [
            0xc4, 16, 1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 3, 4, 5, 6, 7, 8, ]);
    }

    #[test]
    fn cookie_deserialize() {
        let cookie = Cookie::new([1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 3, 4, 5, 6, 7, 8]);
        let deserialized: Cookie = rmps::from_slice(&[
            0xc4, 16, 1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 3, 4, 5, 6, 7, 8, ]).unwrap();
        assert_eq!(cookie, deserialized);
    }

    /// Binary payloads that are not exactly 16 bytes long must be rejected.
    #[test]
    fn cookie_deserialize_invalid_length() {
        let too_short: Result<Cookie, _> = rmps::from_slice(&[
            0xc4, 15, 1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 3, 4, 5, 6, 7, ]);
        assert!(too_short.is_err());

        let too_long: Result<Cookie, _> = rmps::from_slice(&[
            0xc4, 17, 1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 3, 4, 5, 6, 7, 8, 9, ]);
        assert!(too_long.is_err());

        let empty: Result<Cookie, _> = rmps::from_slice(&[0xc4, 0]);
        assert!(empty.is_err());
    }

    /// A random cookie survives a serialize/deserialize round-trip unchanged.
    #[test]
    fn cookie_roundtrip() {
        let cookie = Cookie::random();
        let serialized = rmps::to_vec_named(&cookie).expect("Serialization failed");
        let deserialized: Cookie = rmps::from_slice(&serialized).expect("Deserialization failed");
        assert_eq!(cookie, deserialized);
    }

    /// A new pair has our cookie set and no peer cookie.
    #[test]
    fn cookie_pair_new() {
        let pair = CookiePair::new();
        assert!(pair.theirs.is_none());
        assert_eq!(pair.ours.as_bytes().len(), COOKIE_BYTES);
    }
}