use crate::error::SessionError;
use aead::{generic_array::GenericArray, Aead, NewAead};
use aes_gcm::Aes256Gcm;
use chacha20poly1305::ChaCha20Poly1305;
use rand::rngs::OsRng;
use rand::RngCore;
use serde::de::DeserializeOwned;
use serde::{Deserialize, Serialize};
use std::marker::PhantomData;
use time::OffsetDateTime;
/// A session as handled by a `SessionManager`: an optional timestamp plus an
/// optional caller-defined value of type `V`.
///
/// Field declaration order is part of the serialized form; do not reorder.
#[derive(Clone, Eq, PartialEq, Serialize, Deserialize)]
pub struct Session<V> {
    /// Optional timestamp stored with the session. Nothing in this module
    /// inspects it; callers decide what an expiry means.
    pub expires: Option<OffsetDateTime>,
    /// The session payload, if any.
    pub value: Option<V>,
}
/// Converts `Session<V>` values to and from an opaque byte representation.
///
/// Implementors must be `Send + Sync`.
pub trait SessionManager<V>: Send + Sync
where
    V: Serialize + DeserializeOwned,
{
    /// Reconstructs a session from bytes produced by `serialize`.
    fn deserialize(&self, bytes: &[u8]) -> Result<Session<V>, SessionError>;

    /// Produces the byte representation of `session`.
    fn serialize(&self, session: &Session<V>) -> Result<Vec<u8>, SessionError>;

    /// Reports whether the bytes produced by `serialize` are encrypted.
    fn is_encrypted(&self) -> bool;
}
/// Session manager that seals sessions with ChaCha20-Poly1305 under a
/// 32-byte key.
pub struct ChaCha20Poly1305SessionManager<V>
where
    V: Serialize + DeserializeOwned,
{
    /// Raw AEAD key material.
    aead_key: [u8; 32],
    /// Binds the manager to its session value type without storing a `V`.
    _value: PhantomData<V>,
}
impl<V> ChaCha20Poly1305SessionManager<V>
where
    V: Serialize + DeserializeOwned,
{
    /// Creates a manager that uses `aead_key` for every seal/open operation.
    pub fn from_key(aead_key: [u8; 32]) -> Self {
        Self {
            aead_key,
            _value: PhantomData,
        }
    }

    /// Fills `buf` entirely with bytes from the operating-system RNG.
    fn random_bytes(&self, buf: &mut [u8]) {
        OsRng.fill_bytes(buf)
    }

    /// Builds a cipher instance keyed with the stored key.
    fn aead(&self) -> ChaCha20Poly1305 {
        let key = GenericArray::from_slice(&self.aead_key);
        ChaCha20Poly1305::new(key)
    }
}
impl<V: Serialize + DeserializeOwned + Send + Sync> SessionManager<V>
    for ChaCha20Poly1305SessionManager<V>
{
    /// Opens and decodes a session produced by `serialize`.
    ///
    /// Expected input layout: a 12-byte nonce followed by the AEAD output of
    /// `32 random bytes || CBOR(session)`.
    ///
    /// # Errors
    ///
    /// * `SessionError::ValidationError` if `bytes` is too short to hold any
    ///   payload, or if AEAD authentication fails (modified data or a
    ///   different key).
    /// * `SessionError::InternalError` if the authenticated plaintext cannot
    ///   be decoded as a `Session<V>`.
    fn deserialize(&self, bytes: &[u8]) -> Result<Session<V>, SessionError> {
        const NONCE_LEN: usize = 12;
        const TAG_LEN: usize = 16;
        const PADDING_LEN: usize = 32;

        // Anything not longer than the fixed-size parts cannot carry a payload.
        // This also guarantees the decrypted plaintext is longer than PADDING_LEN,
        // so the slice below cannot panic.
        if bytes.len() <= NONCE_LEN + TAG_LEN + PADDING_LEN {
            return Err(SessionError::ValidationError);
        }
        let (nonce, ciphertext) = bytes.split_at(NONCE_LEN);
        // A failed authentication check means the input is not a valid session
        // for this key. That is a problem with the caller-supplied data, so it
        // is a validation error, as it already is for too-short input above.
        let plaintext = self
            .aead()
            .decrypt(GenericArray::from_slice(nonce), ciphertext)
            .map_err(|_| SessionError::ValidationError)?;
        // Skip the random prefix that `serialize` prepends.
        serde_cbor::from_slice(&plaintext[PADDING_LEN..]).map_err(|err| {
            warn!("Failed to deserialize session: {}", err);
            SessionError::InternalError
        })
    }

    /// Encodes `session` as CBOR and seals it under a fresh random nonce.
    ///
    /// Output layout: a 12-byte nonce followed by the AEAD output of
    /// `32 random bytes || CBOR(session)`.
    ///
    /// # Errors
    ///
    /// `SessionError::InternalError` if CBOR encoding or encryption fails.
    fn serialize(&self, session: &Session<V>) -> Result<Vec<u8>, SessionError> {
        const NONCE_LEN: usize = 12;
        const PADDING_LEN: usize = 32;

        let session_bytes = serde_cbor::to_vec(session).map_err(|err| {
            warn!("Failed to serialize session: {}", err);
            SessionError::InternalError
        })?;

        let mut padding = [0u8; PADDING_LEN];
        self.random_bytes(&mut padding);
        let mut plaintext = Vec::with_capacity(PADDING_LEN + session_bytes.len());
        plaintext.extend_from_slice(&padding);
        plaintext.extend_from_slice(&session_bytes);

        // Fresh random nonce per call; it is sent in the clear ahead of the
        // ciphertext so `deserialize` can recover it.
        let mut nonce = [0u8; NONCE_LEN];
        self.random_bytes(&mut nonce);
        let ciphertext = self
            .aead()
            .encrypt(GenericArray::from_slice(&nonce), plaintext.as_slice())
            .map_err(|_| SessionError::InternalError)?;

        let mut transport = Vec::with_capacity(NONCE_LEN + ciphertext.len());
        transport.extend_from_slice(&nonce);
        transport.extend_from_slice(&ciphertext);
        Ok(transport)
    }

    /// Output of this manager is always encrypted.
    fn is_encrypted(&self) -> bool {
        true
    }
}
/// Session manager that seals sessions with AES-256-GCM under a 32-byte key.
pub struct AesGcmSessionManager<V>
where
    V: Serialize + DeserializeOwned,
{
    /// Raw AEAD key material.
    aead_key: [u8; 32],
    /// Binds the manager to its session value type without storing a `V`.
    _value: PhantomData<V>,
}
impl<V> AesGcmSessionManager<V>
where
    V: Serialize + DeserializeOwned,
{
    /// Creates a manager that uses `aead_key` for every seal/open operation.
    pub fn from_key(aead_key: [u8; 32]) -> Self {
        Self {
            aead_key,
            _value: PhantomData,
        }
    }

    /// Fills `buf` entirely with bytes from the operating-system RNG.
    fn random_bytes(&self, buf: &mut [u8]) {
        OsRng.fill_bytes(buf)
    }

    /// Builds a cipher instance keyed with the stored key.
    fn aead(&self) -> Aes256Gcm {
        let key = GenericArray::from_slice(&self.aead_key);
        Aes256Gcm::new(key)
    }
}
impl<V: Serialize + DeserializeOwned + Send + Sync> SessionManager<V> for AesGcmSessionManager<V> {
    /// Opens and decodes a session produced by `serialize`.
    ///
    /// Expected input layout: a 12-byte nonce followed by the AEAD output of
    /// `32 random bytes || CBOR(session)`.
    ///
    /// # Errors
    ///
    /// * `SessionError::ValidationError` if `bytes` is too short to hold any
    ///   payload, or if AEAD authentication fails (modified data or a
    ///   different key).
    /// * `SessionError::InternalError` if the authenticated plaintext cannot
    ///   be decoded as a `Session<V>`.
    fn deserialize(&self, bytes: &[u8]) -> Result<Session<V>, SessionError> {
        const NONCE_LEN: usize = 12;
        const TAG_LEN: usize = 16;
        const PADDING_LEN: usize = 32;

        // Anything not longer than the fixed-size parts cannot carry a payload.
        // This also guarantees the decrypted plaintext is longer than PADDING_LEN,
        // so the slice below cannot panic.
        if bytes.len() <= NONCE_LEN + TAG_LEN + PADDING_LEN {
            return Err(SessionError::ValidationError);
        }
        let (nonce, ciphertext) = bytes.split_at(NONCE_LEN);
        // A failed authentication check means the input is not a valid session
        // for this key. That is a problem with the caller-supplied data, so it
        // is a validation error, as it already is for too-short input above.
        let plaintext = self
            .aead()
            .decrypt(GenericArray::from_slice(nonce), ciphertext)
            .map_err(|_| SessionError::ValidationError)?;
        // Skip the random prefix that `serialize` prepends.
        serde_cbor::from_slice(&plaintext[PADDING_LEN..]).map_err(|err| {
            warn!("Failed to deserialize session: {}", err);
            SessionError::InternalError
        })
    }

    /// Encodes `session` as CBOR and seals it under a fresh random nonce.
    ///
    /// Output layout: a 12-byte nonce followed by the AEAD output of
    /// `32 random bytes || CBOR(session)`.
    ///
    /// # Errors
    ///
    /// `SessionError::InternalError` if CBOR encoding or encryption fails.
    fn serialize(&self, session: &Session<V>) -> Result<Vec<u8>, SessionError> {
        const NONCE_LEN: usize = 12;
        const PADDING_LEN: usize = 32;

        let session_bytes = serde_cbor::to_vec(session).map_err(|err| {
            warn!("Failed to serialize session: {}", err);
            SessionError::InternalError
        })?;

        let mut padding = [0u8; PADDING_LEN];
        self.random_bytes(&mut padding);
        let mut plaintext = Vec::with_capacity(PADDING_LEN + session_bytes.len());
        plaintext.extend_from_slice(&padding);
        plaintext.extend_from_slice(&session_bytes);

        // Fresh random nonce per call; it is sent in the clear ahead of the
        // ciphertext so `deserialize` can recover it.
        let mut nonce = [0u8; NONCE_LEN];
        self.random_bytes(&mut nonce);
        let ciphertext = self
            .aead()
            .encrypt(GenericArray::from_slice(&nonce), plaintext.as_slice())
            .map_err(|_| SessionError::InternalError)?;

        let mut transport = Vec::with_capacity(NONCE_LEN + ciphertext.len());
        transport.extend_from_slice(&nonce);
        transport.extend_from_slice(&ciphertext);
        Ok(transport)
    }

    /// Output of this manager is always encrypted.
    fn is_encrypted(&self) -> bool {
        true
    }
}
/// Composite session manager: it always writes with `current`, and on read it
/// tries `current` first, then each entry of `previous` in order. This lets
/// sessions written by older managers (e.g. other keys or ciphers) stay readable.
pub struct MultiSessionManager<V>
where
    V: Serialize + DeserializeOwned + Send + Sync,
{
    /// Used for every `serialize` call and tried first on `deserialize`.
    current: Box<dyn SessionManager<V>>,
    /// Fallbacks tried, in order, when `current` cannot deserialize.
    previous: Vec<Box<dyn SessionManager<V>>>,
}
impl<V: Serialize + DeserializeOwned + Send + Sync> MultiSessionManager<V> {
pub fn new(
current: Box<dyn SessionManager<V>>,
previous: Vec<Box<dyn SessionManager<V>>>,
) -> Self {
Self { current, previous }
}
}
impl<V: Serialize + DeserializeOwned + Send + Sync> SessionManager<V> for MultiSessionManager<V> {
    /// Returns the result of the first manager (`current`, then `previous` in
    /// order) that deserializes `bytes` successfully.
    ///
    /// # Errors
    ///
    /// `SessionError::ValidationError` if every manager fails; the individual
    /// errors are discarded.
    fn deserialize(&self, bytes: &[u8]) -> Result<Session<V>, SessionError> {
        std::iter::once(&self.current)
            .chain(self.previous.iter())
            .find_map(|candidate| candidate.deserialize(bytes).ok())
            .ok_or(SessionError::ValidationError)
    }

    /// Delegates to `current`; previous managers are never used for writing.
    fn serialize(&self, session: &Session<V>) -> Result<Vec<u8>, SessionError> {
        let writer = &self.current;
        writer.serialize(session)
    }

    /// Reports whether `current` encrypts its output.
    fn is_encrypted(&self) -> bool {
        let writer = &self.current;
        writer.is_encrypted()
    }
}
#[cfg(test)]
mod tests {
    // Two distinct fixed 32-byte keys, so tests can cover matching and
    // mismatching keys.
    const KEY_1: [u8; 32] = *b"01234567012345670123456701234567";
    const KEY_2: [u8; 32] = *b"76543210765432107654321076543210";

    // Expands to a module `$md` of round-trip, tamper, wrong-key and
    // short-input tests for the single-cipher manager type `$strct`.
    macro_rules! test_cases {
        ($strct: ident, $md: ident) => {
            mod $md {
                use super::{KEY_1, KEY_2};
                use serde::{Deserialize, Serialize};
                use $crate::error::SessionError;
                use $crate::session::{$strct, Session, SessionManager};

                #[derive(Serialize, Deserialize, Eq, PartialEq, Clone, Debug)]
                struct Data {
                    string: String,
                }

                // Serialize then deserialize with the same manager yields the value.
                #[test]
                fn serde_happy_path() {
                    let manager = $strct::from_key(KEY_1);
                    let data = Data {
                        string: "boots and cats".to_string(),
                    };
                    let session = Session {
                        expires: None,
                        value: Some(data.clone()),
                    };
                    let bytes = manager.serialize(&session).expect("couldn't serialize");
                    let parsed_session = manager.deserialize(&bytes).expect("couldn't deserialize");
                    assert_eq!(parsed_session.value, Some(data));
                }

                // Flipping a bit in the last byte must make deserialization fail.
                #[test]
                fn serde_bad_data_end() {
                    let manager = $strct::from_key(KEY_1);
                    let data = Data {
                        string: "boots and cats".to_string(),
                    };
                    let session = Session {
                        expires: None,
                        value: Some(data.clone()),
                    };
                    let mut bytes = manager.serialize(&session).expect("couldn't serialize");
                    let len = bytes.len();
                    bytes[len - 1] ^= 0x01;
                    let deserialized: Result<Session<Data>, SessionError> =
                        manager.deserialize(&bytes);
                    assert!(deserialized.is_err());
                }

                // Flipping a bit in the first byte (the nonce) must make
                // deserialization fail.
                #[test]
                fn serde_bad_data_start() {
                    let manager = $strct::from_key(KEY_1);
                    let data = Data {
                        string: "boots and cats".to_string(),
                    };
                    let session = Session {
                        expires: None,
                        value: Some(data.clone()),
                    };
                    let mut bytes = manager.serialize(&session).expect("couldn't serialize");
                    bytes[0] ^= 0x01;
                    let deserialized: Result<Session<Data>, SessionError> =
                        manager.deserialize(&bytes);
                    assert!(deserialized.is_err());
                }

                // A session sealed under one key must not open under another.
                #[test]
                fn serde_wrong_key() {
                    let sealer = $strct::from_key(KEY_1);
                    let opener: $strct<Data> = $strct::from_key(KEY_2);
                    let session = Session {
                        expires: None,
                        value: Some(Data {
                            string: "boots and cats".to_string(),
                        }),
                    };
                    let bytes = sealer.serialize(&session).expect("couldn't serialize");
                    let deserialized: Result<Session<Data>, SessionError> =
                        opener.deserialize(&bytes);
                    assert!(deserialized.is_err());
                }

                // Inputs no longer than nonce + tag + random prefix are rejected
                // with a validation error before any decryption is attempted.
                #[test]
                fn serde_too_short() {
                    let manager: $strct<Data> = $strct::from_key(KEY_1);
                    for len in [0usize, 12, 60].iter() {
                        let bytes = vec![0u8; *len];
                        let deserialized = manager.deserialize(&bytes);
                        assert!(matches!(deserialized, Err(SessionError::ValidationError)));
                    }
                }
            }
        };
    }

    test_cases!(AesGcmSessionManager, aesgcm);
    test_cases!(ChaCha20Poly1305SessionManager, chacha20poly1305);

    mod multi {
        // Expands to a module `$name` that checks a `MultiSessionManager` with
        // `$strct1` as the current manager and `$strct2` as the single fallback.
        macro_rules! test_cases {
            ($strct1: ident, $strct2: ident, $name: ident) => {
                mod $name {
                    use super::super::{KEY_1, KEY_2};
                    use $crate::session::*;

                    #[derive(Serialize, Deserialize, Eq, PartialEq, Clone, Debug)]
                    struct Data {
                        string: String,
                    }

                    // With no fallbacks, sessions from the wrapped manager and
                    // from the multi manager itself both deserialize.
                    #[test]
                    fn no_previous() {
                        let manager = $strct1::from_key(KEY_1);
                        let mut sessions = vec![];
                        let data = Data {
                            string: "boots and cats".to_string(),
                        };
                        let session = Session {
                            expires: None,
                            value: Some(data.clone()),
                        };
                        let bytes = manager.serialize(&session).expect("couldn't serialize");
                        sessions.push(bytes);
                        let multi = MultiSessionManager::new(Box::new(manager), vec![]);
                        let bytes = multi.serialize(&session).expect("couldn't serialize");
                        sessions.push(bytes);
                        for session in sessions.iter() {
                            let parsed_session =
                                multi.deserialize(session).expect("couldn't deserialize");
                            assert_eq!(parsed_session.value, Some(data.clone()));
                        }
                    }

                    // Sessions from the current manager, the fallback manager, and
                    // the multi manager itself all deserialize.
                    #[test]
                    fn $name() {
                        let manager_1 = $strct1::from_key(KEY_1);
                        let manager_2 = $strct2::from_key(KEY_2);
                        let mut sessions = vec![];
                        let data = Data {
                            string: "boots and cats".to_string(),
                        };
                        let session = Session {
                            expires: None,
                            value: Some(data.clone()),
                        };
                        let bytes = manager_1.serialize(&session).expect("couldn't serialize");
                        sessions.push(bytes);
                        let bytes = manager_2.serialize(&session).expect("couldn't serialize");
                        sessions.push(bytes);
                        let multi = MultiSessionManager::new(
                            Box::new(manager_1),
                            vec![Box::new(manager_2)],
                        );
                        let bytes = multi.serialize(&session).expect("couldn't serialize");
                        sessions.push(bytes);
                        for session in sessions.iter() {
                            let parsed_session =
                                multi.deserialize(session).expect("couldn't deserialize");
                            assert_eq!(parsed_session.value, Some(data.clone()));
                        }
                    }

                    // A session sealed under a key no configured manager holds is
                    // rejected with a validation error.
                    #[test]
                    fn unknown_key() {
                        let manager_1: $strct1<Data> = $strct1::from_key(KEY_1);
                        let manager_2 = $strct2::from_key(KEY_2);
                        let session = Session {
                            expires: None,
                            value: Some(Data {
                                string: "boots and cats".to_string(),
                            }),
                        };
                        let bytes = manager_2.serialize(&session).expect("couldn't serialize");
                        let multi = MultiSessionManager::new(Box::new(manager_1), vec![]);
                        let deserialized = multi.deserialize(&bytes);
                        assert!(matches!(deserialized, Err(SessionError::ValidationError)));
                    }
                }
            };
        }

        test_cases!(
            AesGcmSessionManager,
            AesGcmSessionManager,
            aesgcm_then_aesgcm
        );
        test_cases!(
            ChaCha20Poly1305SessionManager,
            ChaCha20Poly1305SessionManager,
            chacha20poly1305_then_chacha20poly1305
        );
        test_cases!(
            ChaCha20Poly1305SessionManager,
            AesGcmSessionManager,
            chacha20poly1305_then_aesgcm
        );
        test_cases!(
            AesGcmSessionManager,
            ChaCha20Poly1305SessionManager,
            aesgcm_then_chacha20poly1305
        );
    }
}