/// Bytes preceding the nonce in an encrypted cookie: `payload_ver` and `key_id`.
const PREFIX_SIZE: usize = 2;
/// AES-GCM-SIV nonce length in bytes.
const NONCE_SIZE: usize = 12;
/// AES-GCM-SIV authentication tag length in bytes.
const AUTH_TAG_SIZE: usize = 16;
/// Smallest possible decoded cookie: header, nonce and tag around an empty ciphertext.
const MINIMUM_ENCRYPTED_SIZE: usize = PREFIX_SIZE + NONCE_SIZE + AUTH_TAG_SIZE;
/// Length of `MINIMUM_ENCRYPTED_SIZE` bytes once encoded as unpadded base64.
/// Every 3 bytes become 4 characters, and a trailing partial group needs 2 or 3 characters,
/// so the encoded length is `ceil(4 * n / 3)`.
const MINIMUM_COOKIE_VALUE_SIZE: usize = (MINIMUM_ENCRYPTED_SIZE * 4 + 2) / 3;
/// Process-wide 96-bit nonce counter, stored as three 32-bit words with the least
/// significant word first. It is seeded once from the OS RNG by `SessionStoreKey::new`;
/// `None` records that seeding failed.
static NONCE_COUNTER: std::sync::OnceLock<Option<[std::sync::atomic::AtomicU32; 3]>> =
    std::sync::OnceLock::new();
/// A set of AES-256-GCM-SIV keys, each tagged with a one-byte key id, used to seal
/// session payloads into cookies and open them again.
///
/// The first key is the one used for encryption. Decryption picks the key whose id
/// matches the id stored in the cookie. A key set built around a new key can therefore
/// keep older keys (added with `add_key`) to read cookies issued under them.
#[derive(Clone)]
pub struct SessionStoreKey {
    // (cipher, key_id) pairs. Never empty: `new` inserts the first pair and `add_key`
    // only appends. Index 0 is the encryption key.
    keys: Vec<(aes_gcm_siv::Aes256GcmSiv, u8)>,
}
/// Errors returned while building a `SessionStoreKey`.
#[derive(Debug)]
pub enum KeyError {
    /// The operating-system random number generator failed while seeding the nonce counter.
    GetRandomError,
    /// The secret is shorter than 40 bytes.
    SecretTooShort,
    /// `add_key` was given a key id that is already in the set.
    DuplicatedKeyId,
}

impl std::fmt::Display for KeyError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let message = match self {
            KeyError::GetRandomError => "failed to obtain random bytes for the nonce counter",
            KeyError::SecretTooShort => "secret must be at least 40 bytes long",
            KeyError::DuplicatedKeyId => "a key with this id is already registered",
        };
        f.write_str(message)
    }
}

impl std::error::Error for KeyError {}
/// Errors returned by `SessionStoreKey::encrypt`.
#[derive(Debug)]
pub enum EncodeError {
    /// The cipher refused to encrypt the payload (AES-GCM-SIV only fails on over-long input).
    PayloadTooLarge,
    /// The payload could not be serialized as MessagePack.
    MessagePackEncodeError(rmp_serde::encode::Error),
}

impl std::fmt::Display for EncodeError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            EncodeError::PayloadTooLarge => f.write_str("payload is too large to encrypt"),
            EncodeError::MessagePackEncodeError(e) => {
                write!(f, "failed to encode payload as MessagePack: {e}")
            }
        }
    }
}

impl std::error::Error for EncodeError {}
/// Errors returned by `SessionStoreKey::decrypt`.
#[derive(Debug)]
pub enum DecodeError {
    /// The cookie value is too short to hold the header, nonce and authentication tag.
    CookieTooShort,
    /// No key in the set has the key id recorded in the cookie.
    NoKey,
    /// The cookie value is not valid unpadded base64url.
    Base64DecodeError(base64::DecodeError),
    /// Authenticated decryption failed: the key does not match, or the cookie (its name,
    /// header, nonce, ciphertext or tag) was modified.
    DecryptionError(aes_gcm_siv::Error),
    /// The decrypted payload could not be deserialized into the requested type.
    MessagePackDecodeError(rmp_serde::decode::Error),
}

impl std::fmt::Display for DecodeError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            DecodeError::CookieTooShort => {
                f.write_str("cookie value is too short to be an encrypted session")
            }
            DecodeError::NoKey => f.write_str("no key matches the cookie's key id"),
            DecodeError::Base64DecodeError(e) => {
                write!(f, "cookie value is not valid base64url: {e}")
            }
            DecodeError::DecryptionError(e) => {
                write!(f, "cookie failed to decrypt or authenticate: {e}")
            }
            DecodeError::MessagePackDecodeError(e) => {
                write!(f, "failed to decode MessagePack payload: {e}")
            }
        }
    }
}

impl std::error::Error for DecodeError {}
impl SessionStoreKey {
pub fn new(key_id: u8, secret: &str) -> Result<Self, KeyError> {
use std::sync::atomic::AtomicU32;
let nonce = NONCE_COUNTER.get_or_init(|| {
let r0 = AtomicU32::new(getrandom::u32().ok()?);
let r1 = AtomicU32::new(getrandom::u32().ok()?);
let r2 = AtomicU32::new(getrandom::u32().ok()?);
Some([r0, r1, r2])
});
if nonce.is_none() {
return Err(KeyError::GetRandomError);
}
if secret.len() < 40 {
return Err(KeyError::SecretTooShort);
}
let aes_key = Self::derive_key(secret)?;
Ok(Self {
keys: vec![(aes_key, key_id)],
})
}
pub fn add_key(self, key_id: u8, secret: &str) -> Result<Self, KeyError> {
if self.decrypt_key_by_id(key_id).is_some() {
return Err(KeyError::DuplicatedKeyId);
}
let mut updated_self = self;
updated_self.keys.push((Self::derive_key(secret)?, key_id));
Ok(updated_self)
}
pub fn encrypt<'a, T>(
&self,
name: &'a str,
payload: &T,
payload_ver: u8,
) -> Result<cookie::CookieBuilder<'a>, EncodeError>
where
T: serde::Serialize,
{
use aes_gcm_siv::{AeadInPlace, Nonce};
use base64::{Engine, engine::general_purpose::URL_SAFE_NO_PAD};
let mut associated_data = Vec::<u8>::with_capacity(name.len() + 2);
associated_data.push(payload_ver);
associated_data.push(self.encrypt_key_id());
associated_data.extend_from_slice(name.as_bytes());
let mut message = Vec::with_capacity(3072); message.push(payload_ver);
message.push(self.encrypt_key_id());
message.extend_from_slice(&self.nonce());
let mut serializer = rmp_serde::Serializer::new(&mut message)
.with_bytes(rmp_serde::config::BytesMode::ForceAll)
.with_struct_map();
payload
.serialize(&mut serializer)
.map_err(|e| EncodeError::MessagePackEncodeError(e))?;
let (header_nonce, msg_pack) = message.split_at_mut(PREFIX_SIZE + NONCE_SIZE);
let tag = self
.encrypt_key()
.encrypt_in_place_detached(
Nonce::from_slice(&header_nonce[PREFIX_SIZE..]),
&associated_data,
msg_pack,
)
.map_err(|_| EncodeError::PayloadTooLarge)?;
message.extend_from_slice(tag.as_slice());
let cookie_value = URL_SAFE_NO_PAD.encode(&message);
let builder = cookie::Cookie::build((name, cookie_value))
.http_only(true)
.secure(true);
if name.starts_with("__Host-") {
Ok(builder.path("/"))
} else {
Ok(builder)
}
}
pub fn decrypt<T>(&self, cookie: &cookie::Cookie) -> Result<T, DecodeError>
where
T: serde::de::DeserializeOwned,
{
use aes_gcm_siv::{AeadInPlace, Nonce, Tag};
use base64::{Engine, engine::general_purpose::URL_SAFE_NO_PAD};
let (name, value) = cookie.name_value();
if value.len() < MINIMUM_COOKIE_VALUE_SIZE {
return Err(DecodeError::CookieTooShort);
}
let mut encrypted = URL_SAFE_NO_PAD
.decode(value)
.map_err(|e| DecodeError::Base64DecodeError(e))?;
let key_id = encrypted[1];
let mut associated_data = Vec::<u8>::with_capacity(name.len() + 2);
associated_data.push(encrypted[0]); associated_data.push(encrypted[1]); associated_data.extend_from_slice(name.as_bytes());
let (header_nonce, msg_pack) = encrypted.split_at_mut(PREFIX_SIZE + NONCE_SIZE);
let (msg_pack, tag) = msg_pack.split_at_mut(msg_pack.len() - AUTH_TAG_SIZE);
let key = self.decrypt_key_by_id(key_id).ok_or(DecodeError::NoKey)?;
key.decrypt_in_place_detached(
Nonce::from_slice(&header_nonce[PREFIX_SIZE..]),
&associated_data,
msg_pack,
Tag::from_slice(&tag),
)
.map_err(|e| DecodeError::DecryptionError(e))?;
let payload = rmp_serde::from_slice::<T>(&msg_pack)
.map_err(|e| DecodeError::MessagePackDecodeError(e))?;
Ok(payload)
}
pub fn payload_ver(cookie: &cookie::Cookie) -> Option<u8> {
let value = cookie.value();
if let Some(leading_4char) = value.get(0..4) {
use base64::{Engine, engine::general_purpose::URL_SAFE_NO_PAD};
if let Ok(leading_3bytes) = URL_SAFE_NO_PAD.decode(leading_4char) {
Some(leading_3bytes[0])
} else {
None
}
} else {
None
}
}
fn nonce(&self) -> [u8; 12] {
use std::sync::atomic::Ordering::Relaxed;
use std::u32::MAX;
let nonce_counter = NONCE_COUNTER.get().unwrap().as_ref().unwrap();
let u0 = nonce_counter[0].fetch_add(1, Relaxed);
let carry = if u0 == MAX { 1 } else { 0 };
let u1 = nonce_counter[1].fetch_add(carry, Relaxed);
let carry = if u1 == MAX { 1 } else { 0 };
let u2 = nonce_counter[2].fetch_add(carry, Relaxed);
let mut nonce = [0u8; 12];
nonce[0..4].copy_from_slice(&u0.to_le_bytes());
nonce[4..8].copy_from_slice(&u1.to_le_bytes());
nonce[8..12].copy_from_slice(&u2.to_le_bytes());
nonce
}
fn encrypt_key_id(&self) -> u8 {
self.keys[0].1
}
fn encrypt_key<'a>(&'a self) -> &'a aes_gcm_siv::Aes256GcmSiv {
&self.keys[0].0
}
fn decrypt_key_by_id<'a>(&'a self, key_id: u8) -> Option<&'a aes_gcm_siv::Aes256GcmSiv> {
let (key, _kid) = self.keys.iter().find(|(_key, kid)| *kid == key_id)?;
Some(key)
}
fn derive_key(secret: &str) -> Result<aes_gcm_siv::Aes256GcmSiv, KeyError> {
use aes_gcm_siv::KeyInit;
const HKDF_SALT: &[u8] = b"SessionStoreKey AES-GCM-SIV key";
if secret.len() < 40 {
return Err(KeyError::SecretTooShort);
}
let (key, _) = hkdf::Hkdf::<sha2::Sha256>::extract(Some(HKDF_SALT), secret.as_bytes());
let aes_key = aes_gcm_siv::Aes256GcmSiv::new_from_slice(&key).unwrap();
Ok(aes_key)
}
}
impl From<getrandom::Error> for KeyError {
fn from(_e: getrandom::Error) -> Self {
KeyError::GetRandomError
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use cookie::Cookie;

    /// Secrets of exactly the 40-byte minimum length.
    const SECRET_A: &str = "0123456789012345678901234567890123456789";
    const SECRET_B: &str = "1234567890123456789012345678901234567890";

    #[derive(Debug, PartialEq, serde::Deserialize, serde::Serialize)]
    struct Session {
        user_id: i32,
        name: String,
    }

    fn alice() -> Session {
        Session {
            user_id: 1,
            name: "Alice".to_string(),
        }
    }

    #[test]
    fn key_generation() {
        SessionStoreKey::new(0, SECRET_A).unwrap();
        SessionStoreKey::new(0, SECRET_A)
            .unwrap()
            .add_key(1, SECRET_B)
            .unwrap();
    }

    #[test]
    fn too_short_secret() {
        // One byte below the minimum, both for the first key and for an added key.
        let first = SessionStoreKey::new(0, &SECRET_A[..39]);
        assert!(matches!(first, Err(KeyError::SecretTooShort)));
        let added = SessionStoreKey::new(0, SECRET_A)
            .unwrap()
            .add_key(1, &SECRET_B[..39]);
        assert!(matches!(added, Err(KeyError::SecretTooShort)));
    }

    #[test]
    fn duplicated_key_id() {
        let duplicated = SessionStoreKey::new(1, SECRET_A)
            .unwrap()
            .add_key(1, SECRET_B);
        assert!(matches!(duplicated, Err(KeyError::DuplicatedKeyId)));
    }

    #[test]
    fn encrypt_decrypt() {
        let key = SessionStoreKey::new(0, SECRET_A).unwrap();
        let cookie = key.encrypt("session", &alice(), 5).unwrap().build();
        assert_eq!(SessionStoreKey::payload_ver(&cookie), Some(5));
        assert_eq!(key.decrypt::<Session>(&cookie).unwrap(), alice());

        // A rotated key set still reads cookies issued under the older key.
        let rotated = SessionStoreKey::new(1, SECRET_B)
            .unwrap()
            .add_key(0, SECRET_A)
            .unwrap();
        assert_eq!(rotated.decrypt::<Session>(&cookie).unwrap(), alice());
    }

    #[test]
    fn unknown_key_id() {
        let cookie = SessionStoreKey::new(0, SECRET_A)
            .unwrap()
            .encrypt("session", &alice(), 1)
            .unwrap()
            .build();
        let other = SessionStoreKey::new(1, SECRET_A).unwrap();
        assert!(matches!(
            other.decrypt::<Session>(&cookie),
            Err(DecodeError::NoKey)
        ));
    }

    #[test]
    fn malformed_values() {
        let key = SessionStoreKey::new(0, SECRET_A).unwrap();
        let empty = Cookie::new("session", "");
        assert!(matches!(
            key.decrypt::<Session>(&empty),
            Err(DecodeError::CookieTooShort)
        ));
        let not_base64 = Cookie::new("session", "!".repeat(64));
        assert!(matches!(
            key.decrypt::<Session>(&not_base64),
            Err(DecodeError::Base64DecodeError(_))
        ));
    }

    #[test]
    fn cookie_attributes() {
        let key = SessionStoreKey::new(0, SECRET_A).unwrap();
        let host = key.encrypt("__Host-session", &alice(), 1).unwrap().build();
        assert_eq!(host.http_only(), Some(true));
        assert_eq!(host.secure(), Some(true));
        assert_eq!(host.path(), Some("/"));
        assert_eq!(key.decrypt::<Session>(&host).unwrap(), alice());

        let plain = key.encrypt("session", &alice(), 1).unwrap().build();
        assert_eq!(plain.path(), None);
    }

    #[test]
    fn payload_ver_of_unusable_value() {
        assert_eq!(SessionStoreKey::payload_ver(&Cookie::new("session", "abc")), None);
        assert_eq!(SessionStoreKey::payload_ver(&Cookie::new("session", "!!!!")), None);
    }

    #[test]
    fn modified() {
        // Key id 1 shares key material with key id 0, so only the authenticated key-id byte
        // can make the `payload[1] = 1` case fail.
        let key = SessionStoreKey::new(0, SECRET_A)
            .unwrap()
            .add_key(1, SECRET_A)
            .unwrap();
        let plain = "PlainText".to_string();
        let cookie = key.encrypt("session", &plain, 5).unwrap().build();

        let cookie_no_modify = modify_cookie_value(&cookie, |_| {});
        assert!(key.decrypt::<String>(&cookie_no_modify).is_ok());

        let cookie_mod_name = Cookie::new("SESSION", cookie.value());
        assert!(key.decrypt::<String>(&cookie_mod_name).is_err());

        let cookie_mod_ver = modify_cookie_value(&cookie, |payload| payload[0] = 6);
        assert!(key.decrypt::<String>(&cookie_mod_ver).is_err());

        let cookie_mod_keyid = modify_cookie_value(&cookie, |payload| payload[1] = 1);
        assert!(key.decrypt::<String>(&cookie_mod_keyid).is_err());

        let cookie_mod_nonce = modify_cookie_value(&cookie, |payload| payload[3] ^= 0x04);
        assert!(key.decrypt::<String>(&cookie_mod_nonce).is_err());

        let cookie_mod_msg = modify_cookie_value(&cookie, |payload| payload[15] ^= 0x20);
        assert!(key.decrypt::<String>(&cookie_mod_msg).is_err());

        let cookie_mod_tag = modify_cookie_value(&cookie, |payload| {
            let tag_pos = payload.len() - 1;
            payload[tag_pos] ^= 0x80;
        });
        assert!(key.decrypt::<String>(&cookie_mod_tag).is_err());
    }

    /// Decodes the cookie value, lets `f` edit the raw bytes, and re-encodes them under the
    /// same cookie name.
    fn modify_cookie_value<F>(cookie: &Cookie, f: F) -> Cookie<'static>
    where
        F: FnOnce(&mut Vec<u8>),
    {
        use base64::{Engine, engine::general_purpose::URL_SAFE_NO_PAD};
        let (name, value) = cookie.name_value();
        let mut encrypted = URL_SAFE_NO_PAD.decode(value).unwrap();
        f(&mut encrypted);
        Cookie::new(name.to_string(), URL_SAFE_NO_PAD.encode(&encrypted))
    }
}