use base64::Engine;
use hmac::{Hmac, Mac};
use rand::Rng;
use serde::{Deserialize, Serialize};
use sha2::Sha256;
use subtle::ConstantTimeEq;
/// Name of the session cookie.
pub const COOKIE_NAME: &str = "rustango_op_session";
/// Session lifetime in seconds (seven days).
pub const SESSION_TTL_SECS: i64 = 60 * 60 * 24 * 7;
/// Errors produced while validating a session cookie.
#[derive(Debug, thiserror::Error)]
pub enum SessionError {
/// The cookie has no `.` separator, one of its parts is not URL-safe unpadded
/// base64, or the decoded payload is not a valid `SessionPayload` JSON document.
#[error("session cookie is malformed")]
Malformed,
/// The signature does not match the HMAC of the payload under the current secret.
#[error("session signature mismatch")]
BadSignature,
/// The signature is valid but the payload's `exp` timestamp has been reached.
#[error("session expired")]
Expired,
/// The session belongs to a different tenant.
/// NOTE(review): not produced by `decode` in this module; presumably raised by a
/// caller that checks tenant binding — confirm.
#[error("session is bound to a different tenant")]
WrongTenant,
}
/// Contents of a session cookie, serialized as JSON and covered by the HMAC tag.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
pub struct SessionPayload {
    /// Operator id the session was issued for.
    pub oid: i64,
    /// Expiry as a Unix timestamp in seconds; the session counts as expired at
    /// or after this instant.
    pub exp: i64,
    /// Issued-at Unix timestamp in seconds. Falls back to `0` when the field is
    /// absent from the JSON (presumably to accept cookies minted before this
    /// field existed — confirm).
    #[serde(default)]
    pub iat: i64,
}

impl SessionPayload {
    /// Builds a payload for `operator_id`, issued now and expiring `ttl_secs`
    /// seconds from now.
    ///
    /// A non-positive `ttl_secs` yields an already-expired payload. The expiry
    /// uses saturating arithmetic, so a huge TTL clamps to `i64::MAX`. Plain
    /// addition would overflow instead: a panic in debug builds, and a wrap to a
    /// past timestamp in release builds.
    #[must_use]
    pub fn new(operator_id: i64, ttl_secs: i64) -> Self {
        let now = chrono::Utc::now().timestamp();
        Self {
            oid: operator_id,
            exp: now.saturating_add(ttl_secs),
            iat: now,
        }
    }

    /// Returns `true` once the current wall-clock time has reached `exp`.
    fn is_expired(&self) -> bool {
        chrono::Utc::now().timestamp() >= self.exp
    }
}
/// Key material used to HMAC-sign and verify session cookies.
///
/// Only `Clone` is derived. NOTE(review): the missing `Debug` looks deliberate,
/// keeping the key bytes out of `{:?}` output and logs — confirm before adding it.
#[derive(Clone)]
pub struct SessionSecret(Vec<u8>);
/// Reasons `RUSTANGO_SESSION_SECRET` could not be turned into a usable key.
#[derive(Debug)]
pub enum SessionSecretError {
    /// The value could not be decoded as standard base64; `cause` describes why.
    BadBase64 { cause: String },
    /// The value decoded successfully, but to only `actual` bytes (32 are required).
    TooShort { actual: usize },
}

impl core::fmt::Display for SessionSecretError {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        // Each message ends with a hint showing how to generate a valid secret.
        match *self {
            SessionSecretError::TooShort { actual } => write!(
                f,
                "RUSTANGO_SESSION_SECRET decoded to {actual} bytes; need at least 32 (generate one with: openssl rand -base64 32)"
            ),
            SessionSecretError::BadBase64 { ref cause } => write!(
                f,
                "RUSTANGO_SESSION_SECRET is not valid base64: {cause} (generate one with: openssl rand -base64 32)"
            ),
        }
    }
}

impl std::error::Error for SessionSecretError {}
impl SessionSecret {
    /// Minimum key length, in bytes, accepted from the environment or from disk;
    /// also the length of generated random keys.
    const MIN_KEY_BYTES: usize = 32;

    /// Environment variable holding the standard-base64-encoded secret.
    const ENV_VAR: &'static str = "RUSTANGO_SESSION_SECRET";

    /// Loads the secret from `RUSTANGO_SESSION_SECRET`, or falls back to a fresh
    /// random 32-byte key.
    ///
    /// If the variable is set but unusable (not UTF-8, not base64, or shorter
    /// than 32 bytes once decoded), a warning goes to both `tracing` and stderr
    /// before falling back. Sessions signed with a random key do not survive a
    /// process restart.
    #[must_use]
    pub fn from_env_or_random() -> Self {
        match Self::read_env_secret() {
            Some(Ok(secret)) => return secret,
            Some(Err(SessionSecretError::TooShort { actual })) => {
                tracing::warn!(
                    target: "crate::tenancy",
                    actual_len = actual,
                    "RUSTANGO_SESSION_SECRET decoded to fewer than 32 bytes — falling back to random",
                );
                eprintln!(
                    "\x1b[33;1mwarning:\x1b[0m RUSTANGO_SESSION_SECRET is set but \
                     decoded to {} bytes (need ≥ 32). Using a random key. \
                     Sessions will NOT survive a server restart. \
                     Generate one with: \
                     openssl rand -base64 32",
                    actual
                );
            }
            Some(Err(SessionSecretError::BadBase64 { cause })) => {
                tracing::warn!(
                    target: "crate::tenancy",
                    error = %cause,
                    "RUSTANGO_SESSION_SECRET is not valid base64 — falling back to random",
                );
                eprintln!(
                    "\x1b[33;1mwarning:\x1b[0m RUSTANGO_SESSION_SECRET is set but \
                     is not valid base64 ({}). Using a random key. \
                     Sessions will NOT survive a server restart. \
                     Generate one with: \
                     openssl rand -base64 32",
                    cause
                );
            }
            None => {
                tracing::warn!(
                    target: "crate::tenancy",
                    "RUSTANGO_SESSION_SECRET not set — generating random key (sessions \
                     will not survive server restarts; set the env var for production)",
                );
            }
        }
        Self::random_key()
    }

    /// Resolves the secret by precedence: a valid `RUSTANGO_SESSION_SECRET`,
    /// then an existing key file at `disk_path` of at least 32 bytes, then a new
    /// random key.
    ///
    /// A new random key is persisted to `disk_path` so later runs reuse it. The
    /// file is written to a sibling `.tmp` file and then renamed into place. On
    /// Unix the temp file is created with mode `0o600`, so the secret is not
    /// readable by other local users whatever the process umask is.
    ///
    /// If the key cannot be persisted, a warning is logged and the key is used
    /// for this process only. An invalid env var is logged and then ignored.
    #[must_use]
    pub fn from_env_or_disk(disk_path: &std::path::Path) -> Self {
        match Self::read_env_secret() {
            Some(Ok(secret)) => return secret,
            Some(Err(e)) => {
                tracing::warn!(
                    target: "crate::tenancy::session",
                    error = %e,
                    "ignoring invalid RUSTANGO_SESSION_SECRET — falling back to on-disk secret",
                );
            }
            None => {}
        }
        if let Ok(bytes) = std::fs::read(disk_path) {
            if bytes.len() >= Self::MIN_KEY_BYTES {
                tracing::debug!(
                    target: "crate::tenancy::session",
                    path = %disk_path.display(),
                    "loaded persistent session secret from disk",
                );
                return Self(bytes);
            }
        }
        let secret = Self::random_key();
        match Self::persist_key(disk_path, secret.key()) {
            Ok(()) => {
                tracing::info!(
                    target: "crate::tenancy::session",
                    path = %disk_path.display(),
                    "persisted new session secret to disk (dev fallback)",
                );
            }
            Err(e) => {
                tracing::warn!(
                    target: "crate::tenancy::session",
                    path = %disk_path.display(),
                    error = %e,
                    "could not persist session secret to disk — using ephemeral random key (sessions will not survive restart)",
                );
            }
        }
        secret
    }

    /// Loads the secret from `RUSTANGO_SESSION_SECRET`. If the variable is
    /// unset, logs a warning and returns a random 32-byte key.
    ///
    /// # Errors
    ///
    /// Returns an error if the variable is set but unusable:
    ///
    /// - [`SessionSecretError::BadBase64`] when it is not valid UTF-8 or not
    ///   standard base64.
    /// - [`SessionSecretError::TooShort`] when it decodes to fewer than 32 bytes.
    pub fn try_from_env() -> Result<Self, SessionSecretError> {
        if let Some(result) = Self::read_env_secret() {
            return result;
        }
        tracing::warn!(
            target: "crate::tenancy",
            "RUSTANGO_SESSION_SECRET not set — generating random key (sessions \
             will not survive server restarts; set the env var for production)",
        );
        Ok(Self::random_key())
    }

    /// Wraps raw key bytes as-is. Unlike the env/disk constructors, this does
    /// not enforce a minimum length.
    #[must_use]
    pub fn from_bytes(bytes: Vec<u8>) -> Self {
        Self(bytes)
    }

    /// Raw key bytes, used as the HMAC key.
    pub(crate) fn key(&self) -> &[u8] {
        &self.0
    }

    /// Reads and decodes `RUSTANGO_SESSION_SECRET`.
    ///
    /// Returns `None` when the variable is absent. A value that is present but
    /// not valid UTF-8 is reported as an error rather than treated as unset.
    fn read_env_secret() -> Option<Result<Self, SessionSecretError>> {
        let raw = std::env::var_os(Self::ENV_VAR)?;
        Some(match raw.into_string() {
            Ok(text) => Self::decode_base64_key(&text),
            Err(_) => Err(SessionSecretError::BadBase64 {
                cause: "value is not valid UTF-8".to_owned(),
            }),
        })
    }

    /// Decodes surrounding-whitespace-trimmed standard base64 and enforces the
    /// minimum key length.
    fn decode_base64_key(raw: &str) -> Result<Self, SessionSecretError> {
        match base64::engine::general_purpose::STANDARD.decode(raw.trim()) {
            Ok(bytes) if bytes.len() >= Self::MIN_KEY_BYTES => Ok(Self(bytes)),
            Ok(bytes) => Err(SessionSecretError::TooShort {
                actual: bytes.len(),
            }),
            Err(e) => Err(SessionSecretError::BadBase64 {
                cause: e.to_string(),
            }),
        }
    }

    /// Generates a fresh random key of `MIN_KEY_BYTES` bytes.
    fn random_key() -> Self {
        let mut buf = vec![0u8; Self::MIN_KEY_BYTES];
        rand::thread_rng().fill(&mut buf[..]);
        Self(buf)
    }

    /// Writes `key` to `path` through a sibling `.tmp` file plus rename. Any
    /// leftover temp file is removed on failure.
    fn persist_key(path: &std::path::Path, key: &[u8]) -> std::io::Result<()> {
        if let Some(parent) = path.parent() {
            // Best-effort: if this fails, opening the temp file reports the real error.
            let _ = std::fs::create_dir_all(parent);
        }
        let tmp_path = path.with_extension("tmp");
        // A stale temp file may carry looser permissions. `create_new` needs it
        // gone so the restrictive mode is actually applied to a fresh file.
        let _ = std::fs::remove_file(&tmp_path);
        let result = Self::write_private_file(&tmp_path, key)
            .and_then(|()| std::fs::rename(&tmp_path, path));
        if result.is_err() {
            let _ = std::fs::remove_file(&tmp_path);
        }
        result
    }

    /// Creates a new file at `path` (owner-only `0o600` on Unix), writes
    /// `bytes`, and flushes it to stable storage.
    fn write_private_file(path: &std::path::Path, bytes: &[u8]) -> std::io::Result<()> {
        use std::io::Write as _;
        let mut options = std::fs::OpenOptions::new();
        options.write(true).create_new(true);
        #[cfg(unix)]
        {
            use std::os::unix::fs::OpenOptionsExt as _;
            options.mode(0o600);
        }
        let mut file = options.open(path)?;
        file.write_all(bytes)?;
        file.sync_all()
    }
}
/// Produces a cookie value of the form `<payload>.<tag>`.
///
/// `<payload>` is the JSON-serialized `payload` in URL-safe unpadded base64.
/// `<tag>` is the HMAC-SHA256 of that encoded payload text under `secret`, in
/// the same base64 alphabet.
///
/// # Panics
///
/// Panics only if `serde_json` fails to serialize `payload`.
#[must_use]
pub fn encode(secret: &SessionSecret, payload: &SessionPayload) -> String {
    use base64::engine::general_purpose::URL_SAFE_NO_PAD as B64;

    let json = serde_json::to_vec(payload).expect("payload serializes");
    let body = B64.encode(json);
    let tag = B64.encode(sign(secret, body.as_bytes()));
    [body, tag].join(".")
}
/// Verifies and parses a cookie value produced by `encode`.
///
/// The tag over the encoded payload text is compared in constant time before
/// the payload is decoded, so unauthenticated bytes never reach the JSON parser.
///
/// # Errors
///
/// - [`SessionError::Malformed`]: the `.` separator is missing, either part is
///   not URL-safe unpadded base64, or the payload is not valid `SessionPayload`
///   JSON.
/// - [`SessionError::BadSignature`]: the tag does not match under `secret`.
/// - [`SessionError::Expired`]: the payload's `exp` has been reached.
pub fn decode(secret: &SessionSecret, value: &str) -> Result<SessionPayload, SessionError> {
    use base64::engine::general_purpose::URL_SAFE_NO_PAD as B64;

    let (body, tag_text) = match value.split_once('.') {
        Some(parts) => parts,
        None => return Err(SessionError::Malformed),
    };
    let computed = sign(secret, body.as_bytes());
    let presented = match B64.decode(tag_text) {
        Ok(bytes) => bytes,
        Err(_) => return Err(SessionError::Malformed),
    };
    let tag_matches = bool::from(computed.ct_eq(&presented[..]));
    if !tag_matches {
        return Err(SessionError::BadSignature);
    }
    let json = B64.decode(body).map_err(|_| SessionError::Malformed)?;
    let payload = serde_json::from_slice::<SessionPayload>(&json)
        .map_err(|_| SessionError::Malformed)?;
    if payload.is_expired() {
        Err(SessionError::Expired)
    } else {
        Ok(payload)
    }
}
pub(crate) fn sign(secret: &SessionSecret, msg: &[u8]) -> [u8; 32] {
let mut mac =
Hmac::<Sha256>::new_from_slice(secret.key()).expect("HMAC accepts any key length");
mac.update(msg);
let bytes = mac.finalize().into_bytes();
let mut out = [0u8; 32];
out.copy_from_slice(&bytes[..32]);
out
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The 32-byte secret shared by tests that do not care about its exact value.
    fn test_secret() -> SessionSecret {
        SessionSecret::from_bytes(b"a-test-secret-thirty-two-bytes-x".to_vec())
    }

    /// Base64-encodes arbitrary `json` and signs it the same way `encode` does.
    /// This lets tests build correctly signed cookies that `encode` itself
    /// could never emit.
    fn signed_cookie(secret: &SessionSecret, json: &[u8]) -> String {
        let payload_b64 = base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(json);
        let sig_b64 = base64::engine::general_purpose::URL_SAFE_NO_PAD
            .encode(sign(secret, payload_b64.as_bytes()));
        format!("{payload_b64}.{sig_b64}")
    }

    #[test]
    fn round_trip_valid_payload() {
        let secret = test_secret();
        let payload = SessionPayload::new(42, 3600);
        let cookie = encode(&secret, &payload);
        let back = decode(&secret, &cookie).unwrap();
        assert_eq!(back, payload);
    }

    #[test]
    fn rejects_tampered_payload() {
        let secret = test_secret();
        let payload = SessionPayload::new(42, 3600);
        let cookie = encode(&secret, &payload);
        let (_, sig) = cookie.split_once('.').unwrap();
        let evil_payload = base64::engine::general_purpose::URL_SAFE_NO_PAD
            .encode(br#"{"oid":999,"exp":9999999999}"#);
        let tampered = format!("{evil_payload}.{sig}");
        let err = decode(&secret, &tampered).unwrap_err();
        assert!(matches!(err, SessionError::BadSignature));
    }

    #[test]
    fn rejects_wrong_secret() {
        let s1 = SessionSecret::from_bytes(b"first-test-secret-thirty-2-bytes".to_vec());
        let s2 = SessionSecret::from_bytes(b"second-test-secret-thirty2-bytes".to_vec());
        let cookie = encode(&s1, &SessionPayload::new(1, 3600));
        let err = decode(&s2, &cookie).unwrap_err();
        assert!(matches!(err, SessionError::BadSignature));
    }

    #[test]
    fn rejects_expired() {
        let secret = test_secret();
        let payload = SessionPayload::new(1, -10);
        let cookie = encode(&secret, &payload);
        let err = decode(&secret, &cookie).unwrap_err();
        assert!(matches!(err, SessionError::Expired));
    }

    #[test]
    fn rejects_malformed_no_dot() {
        let secret = test_secret();
        let err = decode(&secret, "not-a-cookie").unwrap_err();
        assert!(matches!(err, SessionError::Malformed));
    }

    #[test]
    fn rejects_signature_that_is_not_base64() {
        let secret = test_secret();
        let cookie = encode(&secret, &SessionPayload::new(1, 3600));
        let (payload_b64, _) = cookie.split_once('.').unwrap();
        let err = decode(&secret, &format!("{payload_b64}.not*base64")).unwrap_err();
        assert!(matches!(err, SessionError::Malformed));
    }

    #[test]
    fn rejects_truncated_signature() {
        // A validly encoded but shorter tag must fail the constant-time comparison.
        let secret = test_secret();
        let cookie = encode(&secret, &SessionPayload::new(1, 3600));
        let (payload_b64, _) = cookie.split_once('.').unwrap();
        let full = sign(&secret, payload_b64.as_bytes());
        let short = base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(&full[..16]);
        let err = decode(&secret, &format!("{payload_b64}.{short}")).unwrap_err();
        assert!(matches!(err, SessionError::BadSignature));
    }

    #[test]
    fn rejects_signed_non_json_payload() {
        let secret = test_secret();
        let cookie = signed_cookie(&secret, b"not json");
        let err = decode(&secret, &cookie).unwrap_err();
        assert!(matches!(err, SessionError::Malformed));
    }

    #[test]
    fn accepts_payload_without_iat() {
        let secret = test_secret();
        let exp = chrono::Utc::now().timestamp() + 3600;
        let json = format!(r#"{{"oid":7,"exp":{exp}}}"#);
        let cookie = signed_cookie(&secret, json.as_bytes());
        let payload = decode(&secret, &cookie).unwrap();
        assert_eq!(payload, SessionPayload { oid: 7, exp, iat: 0 });
    }
}