use base64::Engine;
use hmac::{Hmac, Mac};
use rand::Rng;
use serde::{Deserialize, Serialize};
use sha2::Sha256;
use subtle::ConstantTimeEq;
/// Cookie name for the session value.
///
/// NOTE(review): nothing in this part of the file reads this constant;
/// presumably the HTTP layer uses it when setting and reading the cookie
/// produced by `encode` — confirm against callers.
pub const COOKIE_NAME: &str = "rustango_op_session";
/// Session lifetime in seconds: 7 days (7 * 24 * 60 * 60 = 604 800).
///
/// NOTE(review): presumably passed as `ttl_secs` to `SessionPayload::new`
/// by callers — not referenced in this part of the file.
pub const SESSION_TTL_SECS: i64 = 7 * 24 * 60 * 60;
/// Reasons a session cookie value can be rejected.
#[derive(Debug, thiserror::Error)]
pub enum SessionError {
/// The value has no `.` separator, a part is not valid URL-safe base64,
/// or the decoded payload is not valid JSON for `SessionPayload`.
#[error("session cookie is malformed")]
Malformed,
/// The signature carried by the cookie does not match the HMAC computed
/// over its payload part with the current secret.
#[error("session signature mismatch")]
BadSignature,
/// The payload's `exp` timestamp is at or before the current time.
#[error("session expired")]
Expired,
/// NOTE(review): never constructed in this part of the file; presumably
/// raised by a caller that compares the session against the request's
/// tenant — confirm.
#[error("session is bound to a different tenant")]
WrongTenant,
}
/// Data carried inside a signed session cookie.
///
/// `encode` serializes this struct to JSON, so the field names `oid` and
/// `exp` are part of the cookie format.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
pub struct SessionPayload {
/// Operator id the session belongs to (set from `operator_id` in `new`).
pub oid: i64,
/// Expiry as a Unix timestamp in seconds; the session counts as expired
/// once the current UTC timestamp is greater than or equal to this value.
pub exp: i64,
}
impl SessionPayload {
    /// Builds a payload for `operator_id` that expires `ttl_secs` seconds
    /// from the current UTC time.
    ///
    /// A negative `ttl_secs` produces a payload that is already expired.
    /// The expiry is computed with saturating arithmetic, so an extreme
    /// `ttl_secs` clamps to the `i64` range instead of overflowing (which
    /// would panic in debug builds and wrap around in release builds).
    #[must_use]
    pub fn new(operator_id: i64, ttl_secs: i64) -> Self {
        let exp = chrono::Utc::now().timestamp().saturating_add(ttl_secs);
        Self {
            oid: operator_id,
            exp,
        }
    }

    /// Returns `true` once the current UTC timestamp has reached or passed
    /// `exp` (the boundary second itself counts as expired).
    fn is_expired(&self) -> bool {
        chrono::Utc::now().timestamp() >= self.exp
    }
}
/// Raw key bytes used by `sign` as the HMAC-SHA256 key for session cookies.
///
/// NOTE(review): the type derives only `Clone` — presumably `Debug` is
/// omitted on purpose so the key cannot end up in logs; confirm before
/// adding it.
#[derive(Clone)]
pub struct SessionSecret(Vec<u8>);
impl SessionSecret {
    /// Loads the signing key from the `RUSTANGO_SESSION_SECRET` environment
    /// variable, falling back to a freshly generated random key.
    ///
    /// The variable's value is trimmed and decoded with base64's `STANDARD`
    /// engine; it is accepted only when it decodes to at least 32 bytes.
    /// In every other case exactly one warning describing the actual cause
    /// is logged and a random 32-byte key is returned instead. Sessions
    /// signed with a random key stop validating after a restart.
    #[must_use]
    pub fn from_env_or_random() -> Self {
        match std::env::var("RUSTANGO_SESSION_SECRET") {
            Ok(raw) => match base64::engine::general_purpose::STANDARD.decode(raw.trim()) {
                Ok(bytes) if bytes.len() >= 32 => return Self(bytes),
                Ok(bytes) => tracing::warn!(
                    target: "crate::tenancy",
                    actual_len = bytes.len(),
                    "RUSTANGO_SESSION_SECRET decoded to fewer than 32 bytes — falling back to random",
                ),
                Err(e) => tracing::warn!(
                    target: "crate::tenancy",
                    error = %e,
                    "RUSTANGO_SESSION_SECRET is not valid base64 — falling back to random",
                ),
            },
            // Only report "not set" when the lookup itself failed; a value
            // that was present but unusable has already been reported by
            // one of the arms above.
            Err(_) => tracing::warn!(
                target: "crate::tenancy",
                "RUSTANGO_SESSION_SECRET not set — generating random key (sessions \
                 will not survive server restarts; set the env var for production)",
            ),
        }

        let mut buf = vec![0u8; 32];
        rand::thread_rng().fill(&mut buf[..]);
        Self(buf)
    }

    /// Wraps caller-supplied key bytes as-is.
    ///
    /// No length check is performed here, unlike the 32-byte minimum that
    /// `from_env_or_random` enforces on the environment value.
    #[must_use]
    pub fn from_bytes(bytes: Vec<u8>) -> Self {
        Self(bytes)
    }

    /// Borrows the raw key bytes.
    pub(crate) fn key(&self) -> &[u8] {
        &self.0
    }
}
/// Produces the signed cookie value for `payload`.
///
/// The result has the form `<body>.<tag>`: `body` is the payload's JSON in
/// unpadded URL-safe base64, and `tag` is the HMAC-SHA256 of the `body`
/// text (not of the raw JSON), encoded the same way.
///
/// # Panics
///
/// Panics if `payload` fails to serialize to JSON.
#[must_use]
pub fn encode(secret: &SessionSecret, payload: &SessionPayload) -> String {
    use base64::engine::general_purpose::URL_SAFE_NO_PAD as B64;

    let body = B64.encode(serde_json::to_vec(payload).expect("payload serializes"));
    let tag = B64.encode(sign(secret, body.as_bytes()));
    [body.as_str(), tag.as_str()].join(".")
}
/// Verifies a cookie value produced by `encode` and returns its payload.
///
/// The signature is checked, in constant time, before the payload part is
/// decoded or parsed, so unauthenticated bytes never reach the JSON parser.
///
/// # Errors
///
/// * `SessionError::Malformed` — no `.` separator, either part is not
///   valid unpadded URL-safe base64, or the payload is not valid JSON.
/// * `SessionError::BadSignature` — the signature does not match the HMAC
///   of the payload part under `secret`.
/// * `SessionError::Expired` — the signature is valid but `exp` has passed.
pub fn decode(secret: &SessionSecret, value: &str) -> Result<SessionPayload, SessionError> {
    use base64::engine::general_purpose::URL_SAFE_NO_PAD as B64;

    // Split at the first '.', yielding the encoded payload and its tag.
    let (body, tag) = match value.split_once('.') {
        Some(parts) => parts,
        None => return Err(SessionError::Malformed),
    };

    let actual_mac = sign(secret, body.as_bytes());
    let claimed_mac = B64.decode(tag).map_err(|_| SessionError::Malformed)?;

    // Slice comparison: a tag of the wrong length simply compares unequal.
    let authentic: bool = actual_mac[..].ct_eq(&claimed_mac[..]).into();
    if !authentic {
        return Err(SessionError::BadSignature);
    }

    let raw_json = B64.decode(body).map_err(|_| SessionError::Malformed)?;
    let session: SessionPayload =
        serde_json::from_slice(&raw_json).map_err(|_| SessionError::Malformed)?;

    if session.is_expired() {
        Err(SessionError::Expired)
    } else {
        Ok(session)
    }
}
pub(crate) fn sign(secret: &SessionSecret, msg: &[u8]) -> [u8; 32] {
let mut mac = Hmac::<Sha256>::new_from_slice(secret.key())
.expect("HMAC accepts any key length");
mac.update(msg);
let bytes = mac.finalize().into_bytes();
let mut out = [0u8; 32];
out.copy_from_slice(&bytes[..32]);
out
}
#[cfg(test)]
mod tests {
    use super::*;
    use base64::engine::general_purpose::URL_SAFE_NO_PAD;

    /// Fixed 32-byte key shared by the tests that need only one secret.
    fn test_secret() -> SessionSecret {
        SessionSecret::from_bytes(b"a-test-secret-thirty-two-bytes-x".to_vec())
    }

    /// A freshly encoded cookie decodes back to the same payload.
    #[test]
    fn round_trip_valid_payload() {
        let secret = test_secret();
        let original = SessionPayload::new(42, 3600);
        let cookie = encode(&secret, &original);
        assert_eq!(decode(&secret, &cookie).unwrap(), original);
    }

    /// Swapping the payload while reusing a valid signature is rejected.
    #[test]
    fn rejects_tampered_payload() {
        let secret = test_secret();
        let cookie = encode(&secret, &SessionPayload::new(42, 3600));
        let (_, genuine_sig) = cookie.split_once('.').unwrap();
        let forged_body = URL_SAFE_NO_PAD.encode(br#"{"oid":999,"exp":9999999999}"#);
        let forged = format!("{forged_body}.{genuine_sig}");
        assert!(matches!(
            decode(&secret, &forged),
            Err(SessionError::BadSignature)
        ));
    }

    /// A cookie signed with one key does not verify under another.
    #[test]
    fn rejects_wrong_secret() {
        let signer = SessionSecret::from_bytes(b"first-test-secret-thirty-2-bytes".to_vec());
        let verifier = SessionSecret::from_bytes(b"second-test-secret-thirty2-bytes".to_vec());
        let cookie = encode(&signer, &SessionPayload::new(1, 3600));
        assert!(matches!(
            decode(&verifier, &cookie),
            Err(SessionError::BadSignature)
        ));
    }

    /// A correctly signed cookie whose expiry is in the past is rejected.
    #[test]
    fn rejects_expired() {
        let secret = test_secret();
        let cookie = encode(&secret, &SessionPayload::new(1, -10));
        assert!(matches!(
            decode(&secret, &cookie),
            Err(SessionError::Expired)
        ));
    }

    /// Input without a `.` separator is reported as malformed.
    #[test]
    fn rejects_malformed_no_dot() {
        let secret = test_secret();
        assert!(matches!(
            decode(&secret, "not-a-cookie"),
            Err(SessionError::Malformed)
        ));
    }
}