use cookie::Key;
use http::StatusCode;
use http::request::Parts;
use crate::extractors::FromRequest;
use crate::extractors::FromRequestParts;
use crate::responder::Responder;
use crate::types::Request;
/// Purpose label that separates the keys derived from one master key.
///
/// Each variant maps to a fixed byte label (see [`KeyContext::as_bytes`]) that
/// is fed into key derivation, so different purposes get different keys.
///
/// NOTE(review): `Custom` labels are used verbatim, so e.g.
/// `KeyContext::Custom("cookie-signing".into())` produces the same label, and
/// therefore the same derived key, as `KeyContext::Signing`.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum KeyContext {
    /// Label `cookie-signing`.
    Signing,
    /// Label `cookie-encryption`.
    Encryption,
    /// Label `csrf-protection`.
    Csrf,
    /// Label `session-management`.
    Session,
    /// Caller-supplied label, used byte-for-byte.
    Custom(String),
}

impl KeyContext {
    /// Returns the byte label used for this context during key derivation.
    pub fn as_bytes(&self) -> &[u8] {
        let label: &str = match self {
            Self::Signing => "cookie-signing",
            Self::Encryption => "cookie-encryption",
            Self::Csrf => "csrf-protection",
            Self::Session => "session-management",
            Self::Custom(purpose) => purpose.as_str(),
        };
        label.as_bytes()
    }
}
/// Inputs for deriving purpose-specific cookie keys from a single master key.
///
/// The `CookieKeyExpansion` extractor reads this value from the request
/// extensions, so it must be present there for that extractor to succeed.
#[derive(Clone)]
pub struct KeyExpansionConfig {
    /// Master key material that every derived key is computed from.
    pub master_key: Key,
    /// Application-specific bytes mixed into every derivation.
    pub app_info: Vec<u8>,
    /// Number of bytes to derive.
    ///
    /// `cookie::Key::try_from` rejects fewer than 64 bytes and the derivation
    /// rejects more than 64, so 64 (the default) is the only length that can
    /// produce a usable key.
    pub key_length: usize,
}

impl std::fmt::Debug for KeyExpansionConfig {
    /// Formats the config without printing the master key material, so the
    /// secret cannot end up in logs via `{:?}`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("KeyExpansionConfig")
            .field("master_key", &"<redacted>")
            .field("app_info", &self.app_info)
            .field("key_length", &self.key_length)
            .finish()
    }
}

impl KeyExpansionConfig {
    /// Default derived-key length: the 64 bytes `cookie::Key` is built from.
    /// (A smaller default makes `Key::try_from` fail on every derivation.)
    const DEFAULT_KEY_LENGTH: usize = 64;

    /// Creates a config with the default key length of 64 bytes.
    ///
    /// `app_info` is any byte-convertible value (e.g. `&str`, `Vec<u8>`).
    pub fn new(master_key: Key, app_info: impl Into<Vec<u8>>) -> Self {
        Self {
            master_key,
            app_info: app_info.into(),
            key_length: Self::DEFAULT_KEY_LENGTH,
        }
    }

    /// Overrides the number of bytes to derive.
    ///
    /// Only 64 yields a usable `cookie::Key`; other values make derivation
    /// return an error.
    pub fn with_key_length(mut self, length: usize) -> Self {
        self.key_length = length;
        self
    }
}
/// Derives purpose-specific cookie keys from a [`KeyExpansionConfig`].
///
/// Built directly with `CookieKeyExpansion::new`, or extracted from a request
/// (or request parts) whose extensions contain a `KeyExpansionConfig`.
pub struct CookieKeyExpansion {
    // Master key, application info and output length used for every derivation.
    config: KeyExpansionConfig,
}
/// Errors produced while extracting or using a `CookieKeyExpansion`.
#[derive(Debug)]
pub enum CookieKeyExpansionError {
    /// No `KeyExpansionConfig` was found in the request extensions.
    MissingConfig,
    /// The master key cannot be used for derivation.
    InvalidMasterKey,
    /// Derivation failed; the payload describes the underlying cause.
    DerivationFailed(String),
    /// The configured key length cannot produce a usable key.
    InvalidKeyLength,
    /// The requested derivation algorithm is not supported.
    UnsupportedAlgorithm,
}

impl std::fmt::Display for CookieKeyExpansionError {
    /// Human-readable messages; these match the bodies of the HTTP error
    /// responses produced for each variant.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::MissingConfig => {
                f.write_str("Key expansion configuration not found in request extensions")
            }
            Self::InvalidMasterKey => f.write_str("Invalid master key for key derivation"),
            Self::DerivationFailed(err) => write!(f, "Key derivation failed: {err}"),
            Self::InvalidKeyLength => f.write_str("Invalid key length specified for derivation"),
            Self::UnsupportedAlgorithm => f.write_str("Unsupported key derivation algorithm"),
        }
    }
}

// Lets the error flow through `?` into `Box<dyn std::error::Error>` and similar.
impl std::error::Error for CookieKeyExpansionError {}
impl Responder for CookieKeyExpansionError {
    /// Converts the error into a `500 Internal Server Error` response whose
    /// body names the failure; every variant uses the same status code.
    fn into_response(self) -> crate::types::Response {
        let status = StatusCode::INTERNAL_SERVER_ERROR;
        match self {
            Self::MissingConfig => {
                (status, "Key expansion configuration not found in request extensions")
                    .into_response()
            }
            Self::InvalidMasterKey => {
                (status, "Invalid master key for key derivation").into_response()
            }
            Self::DerivationFailed(err) => {
                let body = format!("Key derivation failed: {err}");
                (status, body).into_response()
            }
            Self::InvalidKeyLength => {
                (status, "Invalid key length specified for derivation").into_response()
            }
            Self::UnsupportedAlgorithm => {
                (status, "Unsupported key derivation algorithm").into_response()
            }
        }
    }
}
impl CookieKeyExpansion {
    /// Number of bytes `cookie::Key` is built from; `Key::try_from` rejects
    /// anything shorter.
    const COOKIE_KEY_LEN: usize = 64;

    /// Creates an expander around `config`.
    pub fn new(config: KeyExpansionConfig) -> Self {
        Self { config }
    }

    /// Derives the key for `context` with no additional info.
    ///
    /// # Errors
    /// Same as [`CookieKeyExpansion::derive_key_with_info`].
    pub fn derive_key(&self, context: KeyContext) -> Result<Key, CookieKeyExpansionError> {
        self.derive_key_with_info(context, &[])
    }

    /// Derives a key for `context`, optionally bound to `additional_info`.
    ///
    /// The key is HKDF-SHA256 (RFC 5869) over the master key, with the info
    /// string `context ‖ 0x00 ‖ app_info`, followed by `0x00 ‖ additional_info`
    /// when `additional_info` is non-empty. The derivation is deterministic.
    /// Changing any of these inputs changes the key and therefore invalidates
    /// cookies issued with the old one.
    ///
    /// # Errors
    /// - `InvalidKeyLength` if `config.key_length` is not 64.
    /// - `DerivationFailed` if the derived bytes are rejected by `Key::try_from`.
    pub fn derive_key_with_info(
        &self,
        context: KeyContext,
        additional_info: &[u8],
    ) -> Result<Key, CookieKeyExpansionError> {
        // Lengths above 64 were always rejected here, and `Key::try_from` rejects
        // anything below 64, so 64 is the only length that can ever succeed.
        // Report other lengths as a configuration error up front.
        if self.config.key_length != Self::COOKIE_KEY_LEN {
            return Err(CookieKeyExpansionError::InvalidKeyLength);
        }
        let label = context.as_bytes();
        let app_info = &self.config.app_info;
        let mut info = Vec::with_capacity(label.len() + app_info.len() + additional_info.len() + 2);
        info.extend_from_slice(label);
        // 0x00 separates the fields. NOTE(review): this is not fully unambiguous
        // if a field itself contains 0x00 bytes.
        info.push(0x00);
        info.extend_from_slice(app_info);
        if !additional_info.is_empty() {
            info.push(0x00);
            info.extend_from_slice(additional_info);
        }
        let derived_key = self.hkdf_expand(&info)?;
        Key::try_from(derived_key.as_slice())
            .map_err(|e| CookieKeyExpansionError::DerivationFailed(e.to_string()))
    }

    /// Derives one key per context, in order, pairing each key with its context.
    ///
    /// # Errors
    /// Returns the first error encountered; no partial result is returned.
    pub fn derive_keys(
        &self,
        contexts: &[KeyContext],
    ) -> Result<Vec<(KeyContext, Key)>, CookieKeyExpansionError> {
        contexts
            .iter()
            .map(|context| {
                let key = self.derive_key(context.clone())?;
                Ok((context.clone(), key))
            })
            .collect()
    }

    /// Key for [`KeyContext::Signing`].
    pub fn signing_key(&self) -> Result<Key, CookieKeyExpansionError> {
        self.derive_key(KeyContext::Signing)
    }

    /// Key for [`KeyContext::Encryption`].
    pub fn encryption_key(&self) -> Result<Key, CookieKeyExpansionError> {
        self.derive_key(KeyContext::Encryption)
    }

    /// Key for [`KeyContext::Csrf`].
    pub fn csrf_key(&self) -> Result<Key, CookieKeyExpansionError> {
        self.derive_key(KeyContext::Csrf)
    }

    /// Key for [`KeyContext::Session`].
    pub fn session_key(&self) -> Result<Key, CookieKeyExpansionError> {
        self.derive_key(KeyContext::Session)
    }

    /// Derives `config.key_length` bytes from the master key with HKDF-SHA256,
    /// using an empty salt and the given `info`.
    ///
    /// A cryptographic KDF is required here. A general-purpose hasher such as
    /// `DefaultHasher` has an unspecified algorithm that may change between Rust
    /// releases, and it yields only 64 bits per call.
    fn hkdf_expand(&self, info: &[u8]) -> Result<Vec<u8>, CookieKeyExpansionError> {
        hkdf_sha256::derive(&[], self.config.master_key.master(), info, self.config.key_length)
            .ok_or_else(|| {
                CookieKeyExpansionError::DerivationFailed(
                    "requested length exceeds the HKDF-SHA256 output limit".to_string(),
                )
            })
    }

    /// Returns the configuration this expander derives keys from.
    pub fn config(&self) -> &KeyExpansionConfig {
        &self.config
    }

    /// Builds an expander from a clone of the `KeyExpansionConfig` stored in the
    /// request extensions; `MissingConfig` if none is present.
    fn extract_from_request(req: &Request) -> Result<Self, CookieKeyExpansionError> {
        let config = req
            .extensions()
            .get::<KeyExpansionConfig>()
            .ok_or(CookieKeyExpansionError::MissingConfig)?;
        Ok(Self::new(config.clone()))
    }

    /// Same as `extract_from_request`, but reads the extensions of request `Parts`.
    fn extract_from_parts(parts: &Parts) -> Result<Self, CookieKeyExpansionError> {
        let config = parts
            .extensions
            .get::<KeyExpansionConfig>()
            .ok_or(CookieKeyExpansionError::MissingConfig)?;
        Ok(Self::new(config.clone()))
    }
}

/// Minimal, dependency-free HKDF-SHA256 (RFC 5869), built on SHA-256
/// (FIPS 180-4) and HMAC (RFC 2104).
///
/// Private to this module. Prefer the `hkdf` + `sha2` crates if they become
/// available as dependencies.
mod hkdf_sha256 {
    /// SHA-256 digest size in bytes.
    const HASH_LEN: usize = 32;
    /// SHA-256 block size in bytes (also the HMAC key block size).
    const BLOCK_LEN: usize = 64;

    /// Initial hash state (FIPS 180-4 §5.3.3).
    const H0: [u32; 8] = [
        0x6a09e667, 0xbb67ae85, 0x3c6ef372, 0xa54ff53a, 0x510e527f, 0x9b05688c, 0x1f83d9ab,
        0x5be0cd19,
    ];

    /// Round constants (FIPS 180-4 §4.2.2).
    const K: [u32; 64] = [
        0x428a2f98, 0x71374491, 0xb5c0fbcf, 0xe9b5dba5, 0x3956c25b, 0x59f111f1, 0x923f82a4,
        0xab1c5ed5, 0xd807aa98, 0x12835b01, 0x243185be, 0x550c7dc3, 0x72be5d74, 0x80deb1fe,
        0x9bdc06a7, 0xc19bf174, 0xe49b69c1, 0xefbe4786, 0x0fc19dc6, 0x240ca1cc, 0x2de92c6f,
        0x4a7484aa, 0x5cb0a9dc, 0x76f988da, 0x983e5152, 0xa831c66d, 0xb00327c8, 0xbf597fc7,
        0xc6e00bf3, 0xd5a79147, 0x06ca6351, 0x14292967, 0x27b70a85, 0x2e1b2138, 0x4d2c6dfc,
        0x53380d13, 0x650a7354, 0x766a0abb, 0x81c2c92e, 0x92722c85, 0xa2bfe8a1, 0xa81a664b,
        0xc24b8b70, 0xc76c51a3, 0xd192e819, 0xd6990624, 0xf40e3585, 0x106aa070, 0x19a4c116,
        0x1e376c08, 0x2748774c, 0x34b0bcb5, 0x391c0cb3, 0x4ed8aa4a, 0x5b9cca4f, 0x682e6ff3,
        0x748f82ee, 0x78a5636f, 0x84c87814, 0x8cc70208, 0x90befffa, 0xa4506ceb, 0xbef9a3f7,
        0xc67178f2,
    ];

    /// Applies the SHA-256 compression function to one 64-byte block.
    fn compress(state: &mut [u32; 8], block: &[u8]) {
        let mut w = [0u32; 64];
        for (slot, bytes) in w.iter_mut().zip(block.chunks_exact(4)) {
            *slot = u32::from_be_bytes([bytes[0], bytes[1], bytes[2], bytes[3]]);
        }
        for t in 16..64 {
            let s0 = w[t - 15].rotate_right(7) ^ w[t - 15].rotate_right(18) ^ (w[t - 15] >> 3);
            let s1 = w[t - 2].rotate_right(17) ^ w[t - 2].rotate_right(19) ^ (w[t - 2] >> 10);
            w[t] = w[t - 16].wrapping_add(s0).wrapping_add(w[t - 7]).wrapping_add(s1);
        }

        let [mut a, mut b, mut c, mut d, mut e, mut f, mut g, mut h] = *state;
        for (&k, &wt) in K.iter().zip(w.iter()) {
            let big_s1 = e.rotate_right(6) ^ e.rotate_right(11) ^ e.rotate_right(25);
            let ch = (e & f) ^ (!e & g);
            let t1 = h
                .wrapping_add(big_s1)
                .wrapping_add(ch)
                .wrapping_add(k)
                .wrapping_add(wt);
            let big_s0 = a.rotate_right(2) ^ a.rotate_right(13) ^ a.rotate_right(22);
            let maj = (a & b) ^ (a & c) ^ (b & c);
            let t2 = big_s0.wrapping_add(maj);
            h = g;
            g = f;
            f = e;
            e = d.wrapping_add(t1);
            d = c;
            c = b;
            b = a;
            a = t1.wrapping_add(t2);
        }

        for (word, add) in state.iter_mut().zip([a, b, c, d, e, f, g, h]) {
            *word = word.wrapping_add(add);
        }
    }

    /// SHA-256 of the concatenation of `parts`.
    fn sha256(parts: &[&[u8]]) -> [u8; HASH_LEN] {
        let mut message: Vec<u8> = parts.concat();
        let bit_len = (message.len() as u64).wrapping_mul(8);
        // Padding: 0x80, zeros up to 56 mod 64, then the 64-bit big-endian bit length.
        message.push(0x80);
        while message.len() % BLOCK_LEN != BLOCK_LEN - 8 {
            message.push(0);
        }
        message.extend_from_slice(&bit_len.to_be_bytes());

        let mut state = H0;
        for block in message.chunks_exact(BLOCK_LEN) {
            compress(&mut state, block);
        }

        let mut digest = [0u8; HASH_LEN];
        for (out, word) in digest.chunks_exact_mut(4).zip(state) {
            out.copy_from_slice(&word.to_be_bytes());
        }
        digest
    }

    /// HMAC-SHA256 of the concatenation of `parts` under `key`.
    fn hmac(key: &[u8], parts: &[&[u8]]) -> [u8; HASH_LEN] {
        let mut padded_key = [0u8; BLOCK_LEN];
        if key.len() > BLOCK_LEN {
            padded_key[..HASH_LEN].copy_from_slice(&sha256(&[key]));
        } else {
            padded_key[..key.len()].copy_from_slice(key);
        }
        let inner_pad = padded_key.map(|byte| byte ^ 0x36);
        let outer_pad = padded_key.map(|byte| byte ^ 0x5c);

        let mut inner_input: Vec<&[u8]> = Vec::with_capacity(parts.len() + 1);
        inner_input.push(&inner_pad[..]);
        inner_input.extend_from_slice(parts);
        let inner_digest = sha256(&inner_input);

        sha256(&[&outer_pad[..], &inner_digest[..]])
    }

    /// HKDF-SHA256 extract-and-expand producing `len` bytes.
    ///
    /// Returns `None` when `len` exceeds RFC 5869's limit of 255 * 32 bytes.
    pub(super) fn derive(salt: &[u8], ikm: &[u8], info: &[u8], len: usize) -> Option<Vec<u8>> {
        if len > 255 * HASH_LEN {
            return None;
        }
        // Extract: PRK = HMAC(salt, IKM). An empty salt equals the RFC default
        // of HashLen zero bytes, because HMAC zero-pads its key.
        let prk = hmac(salt, &[ikm]);

        // Expand: T(i) = HMAC(PRK, T(i-1) ‖ info ‖ i), with T(0) empty.
        let mut okm = Vec::with_capacity(len);
        let mut previous: Vec<u8> = Vec::new();
        for counter in 1..=u8::MAX {
            if okm.len() >= len {
                break;
            }
            let counter_byte = [counter];
            let block = hmac(&prk, &[previous.as_slice(), info, &counter_byte[..]]);
            let take = (len - okm.len()).min(HASH_LEN);
            okm.extend_from_slice(&block[..take]);
            previous = block.to_vec();
        }
        Some(okm)
    }
}
impl<'a> FromRequest<'a> for CookieKeyExpansion {
    type Error = CookieKeyExpansionError;

    /// Builds the expander from the `KeyExpansionConfig` in the request
    /// extensions. The lookup runs immediately, so the returned future is
    /// already resolved.
    fn from_request(
        req: &'a mut Request,
    ) -> impl core::future::Future<Output = core::result::Result<Self, Self::Error>> + Send + 'a {
        let extracted = Self::extract_from_request(req);
        futures_util::future::ready(extracted)
    }
}
impl<'a> FromRequestParts<'a> for CookieKeyExpansion {
    type Error = CookieKeyExpansionError;

    /// Builds the expander from the `KeyExpansionConfig` in the request parts'
    /// extensions. The lookup runs immediately, so the returned future is
    /// already resolved.
    fn from_request_parts(
        parts: &'a mut Parts,
    ) -> impl core::future::Future<Output = core::result::Result<Self, Self::Error>> + Send + 'a {
        let extracted = Self::extract_from_parts(parts);
        futures_util::future::ready(extracted)
    }
}