use std::num::NonZeroUsize;
use std::path::Path;
use std::sync::{Arc, Mutex};
use std::time::{SystemTime, UNIX_EPOCH};
use anyhow::anyhow;
use base64::Engine as _;
use p256::ecdsa::signature::{DigestSigner, DigestVerifier};
use p256::ecdsa::{Signature, SigningKey, VerifyingKey};
use p256::pkcs8::{DecodePrivateKey, EncodePrivateKey, LineEnding};
use rand_core::OsRng;
use sha2::{Digest, Sha256};
use crate::auth::{Authenticator, Identity};
/// Outcome of validating a JWT presented by a client.
#[derive(Debug)]
pub enum JwtResult {
    /// The token checked out (cache hit, or `alg`, `kid`, signature, `exp`
    /// and `sub` all verified); carries the identity built from its claims.
    Valid(Identity),
    /// The token is malformed, names an algorithm other than `ES256`, fails
    /// signature verification, lacks a usable `exp`/`sub` claim, or the
    /// system clock could not be read.
    Invalid,
    /// The token's `exp` is at or before the current time.
    Expired,
    /// The token header's `kid` does not match this manager's key.
    NotMine,
}
/// Base64url engine without padding, used for JWT segments, the JWK
/// coordinates and the key ID.
const B64: base64::engine::general_purpose::GeneralPurpose =
    base64::engine::general_purpose::URL_SAFE_NO_PAD;
/// Maximum number of verified tokens held in the validation LRU cache.
const JWT_CACHE_CAPACITY: usize = 1024;
/// Claims of a token whose signature has already been verified, stored in
/// the validation cache so repeat presentations skip ECDSA verification.
struct CachedToken {
    /// Identity decoded from the token's `sub` and `groups` claims.
    identity: Identity,
    /// The token's `exp` claim, in seconds since the Unix epoch; compared
    /// against the current time on every cache hit.
    exp: u64,
}
/// Settings controlling how session tokens are issued and carried.
///
/// Derives `Debug` (for diagnostics/logging of configuration) and `Clone`;
/// it contains no secret material.
#[derive(Debug, Clone)]
pub struct JwtConfig {
    /// Name of the cookie that carries the token; also used when building
    /// the `Set-Cookie` header value.
    pub cookie_name: String,
    /// Token lifetime in seconds; determines the `exp` claim and the
    /// cookie's `Max-Age`.
    pub validity_secs: u64,
}
/// Issues and validates ES256 (ECDSA P-256 / SHA-256) JWTs using a key
/// loaded from, or generated into, the state directory.
pub struct JwtManager {
    /// Private key used to sign issued tokens.
    signing_key: SigningKey,
    /// Public half of `signing_key`; verifies tokens and is published via
    /// the JWKS document.
    verifying_key: VerifyingKey,
    /// Key ID written into token headers: the base64url SHA-256 thumbprint
    /// of `verifying_key`. Tokens with a different `kid` are "not mine".
    pub kid: String,
    /// Cookie name and token lifetime.
    config: JwtConfig,
    /// Optional wrapped authenticator; `is_session_mode` reports whether it
    /// is set. NOTE(review): how callers use it is not visible here.
    pub inner: Option<Arc<dyn Authenticator>>,
    /// LRU cache of verified tokens, keyed by the SHA-256 digest of the raw
    /// token string.
    cache: Mutex<lru::LruCache<[u8; 32], CachedToken>>,
}
impl JwtManager {
pub fn load_or_generate(
state_dir: &Path,
config: JwtConfig,
inner: Option<Arc<dyn Authenticator>>,
) -> anyhow::Result<Self> {
let key_dir = state_dir.join("jwt");
std::fs::create_dir_all(&key_dir).map_err(|e| {
anyhow!("creating jwt key dir {}: {e}", key_dir.display())
})?;
let key_path = key_dir.join("ec-key.pem");
let signing_key = if key_path.exists() {
let pem = std::fs::read_to_string(&key_path).map_err(|e| {
anyhow!("reading jwt key {}: {e}", key_path.display())
})?;
SigningKey::from_pkcs8_pem(&pem).map_err(|e| {
anyhow!("parsing jwt key {}: {e}", key_path.display())
})?
} else {
let key = SigningKey::random(&mut OsRng);
let pem = key
.to_pkcs8_pem(LineEnding::LF)
.map_err(|e| anyhow!("encoding jwt key: {e}"))?;
write_private_file(&key_path, pem.as_bytes()).map_err(|e| {
anyhow!("writing jwt key {}: {e}", key_path.display())
})?;
tracing::info!(
path = %key_path.display(),
"generated new EC key"
);
key
};
let verifying_key = *signing_key.verifying_key();
let kid = compute_kid(&verifying_key);
let cache = Mutex::new(lru::LruCache::new(
NonZeroUsize::new(JWT_CACHE_CAPACITY).unwrap(),
));
Ok(Self {
signing_key,
verifying_key,
kid,
config,
inner,
cache,
})
}
pub fn validate(&self, headers: &hyper::HeaderMap) -> Option<JwtResult> {
let token = extract_token(headers, &self.config.cookie_name)?;
match self.validate_token(&token) {
JwtResult::NotMine => None,
other => Some(other),
}
}
fn validate_token(&self, token: &str) -> JwtResult {
let key: [u8; 32] = Sha256::digest(token.as_bytes()).into();
let now = match now_secs() {
Some(t) => t,
None => return JwtResult::Invalid,
};
{
let mut cache = self.cache.lock().expect("jwt cache mutex");
if let Some(cached) = cache.get(&key) {
if cached.exp > now {
return JwtResult::Valid(cached.identity.clone());
}
cache.pop(&key);
return JwtResult::Expired;
}
}
let parts: Vec<&str> = token.splitn(3, '.').collect();
if parts.len() != 3 {
tracing::debug!("invalid token format");
return JwtResult::Invalid;
}
let (header_b64, payload_b64, sig_b64) = (parts[0], parts[1], parts[2]);
let header_bytes = match B64.decode(header_b64) {
Ok(b) => b,
Err(_) => return JwtResult::Invalid,
};
let header: serde_json::Value =
match serde_json::from_slice(&header_bytes) {
Ok(v) => v,
Err(_) => return JwtResult::Invalid,
};
if header.get("alg").and_then(|v| v.as_str()) != Some("ES256") {
tracing::debug!("unexpected algorithm in header");
return JwtResult::Invalid;
}
if header.get("kid").and_then(|v| v.as_str()) != Some(self.kid.as_str())
{
tracing::debug!("kid mismatch (not our token)");
return JwtResult::NotMine;
}
let sig_bytes = match B64.decode(sig_b64) {
Ok(b) => b,
Err(_) => return JwtResult::Invalid,
};
let sig = match Signature::from_slice(&sig_bytes) {
Ok(s) => s,
Err(_) => return JwtResult::Invalid,
};
let signed_input = format!("{header_b64}.{payload_b64}");
let digest = Sha256::new_with_prefix(signed_input.as_bytes());
if self.verifying_key.verify_digest(digest, &sig).is_err() {
tracing::debug!("signature verification failed");
return JwtResult::Invalid;
}
let payload_bytes = match B64.decode(payload_b64) {
Ok(b) => b,
Err(_) => return JwtResult::Invalid,
};
let payload: serde_json::Value =
match serde_json::from_slice(&payload_bytes) {
Ok(v) => v,
Err(_) => return JwtResult::Invalid,
};
let exp = match payload.get("exp").and_then(|v| v.as_u64()) {
Some(e) => e,
None => return JwtResult::Invalid,
};
if exp <= now {
tracing::debug!("token expired");
return JwtResult::Expired;
}
let username = match payload
.get("sub")
.and_then(|v| v.as_str())
.map(str::to_owned)
{
Some(u) => u,
None => return JwtResult::Invalid,
};
let groups: Vec<String> = payload
.get("groups")
.and_then(|g| g.as_array())
.map(|arr| {
arr.iter()
.filter_map(|v| v.as_str().map(str::to_owned))
.collect()
})
.unwrap_or_default();
let identity = Identity { username, groups };
self.cache.lock().expect("jwt cache mutex").put(
key,
CachedToken {
identity: identity.clone(),
exp,
},
);
JwtResult::Valid(identity)
}
pub fn issue(&self, identity: &Identity) -> anyhow::Result<String> {
let now = now_secs().ok_or_else(|| anyhow!("system clock error"))?;
let exp = now + self.config.validity_secs;
let header = serde_json::json!({
"alg": "ES256",
"typ": "JWT",
"kid": self.kid,
});
let payload = serde_json::json!({
"sub": identity.username,
"groups": identity.groups,
"iat": now,
"exp": exp,
"iss": "hypershunt",
});
let h = B64.encode(serde_json::to_string(&header)?);
let p = B64.encode(serde_json::to_string(&payload)?);
let signed_input = format!("{h}.{p}");
let digest = Sha256::new_with_prefix(signed_input.as_bytes());
let sig: Signature = self.signing_key.sign_digest(digest);
let s = B64.encode(sig.to_bytes());
Ok(format!("{signed_input}.{s}"))
}
pub fn make_set_cookie(
&self,
identity: &Identity,
is_tls: bool,
) -> anyhow::Result<String> {
let token = self.issue(identity)?;
let mut cookie = format!(
"{}={}; Path=/; HttpOnly; SameSite=Strict; Max-Age={}",
self.config.cookie_name, token, self.config.validity_secs,
);
if is_tls {
cookie.push_str("; Secure");
}
Ok(cookie)
}
pub fn jwks_json(&self) -> String {
let ep = self.verifying_key.to_encoded_point(false);
let x = B64.encode(ep.x().expect("uncompressed point has x"));
let y = B64.encode(ep.y().expect("uncompressed point has y"));
serde_json::to_string(&serde_json::json!({
"keys": [{
"kty": "EC",
"crv": "P-256",
"use": "sig",
"alg": "ES256",
"kid": self.kid,
"x": x,
"y": y,
}]
}))
.expect("JWKS serialization is infallible")
}
pub fn is_session_mode(&self) -> bool {
self.inner.is_some()
}
pub fn cookie_name(&self) -> &str {
&self.config.cookie_name
}
}
fn extract_token(
headers: &hyper::HeaderMap,
cookie_name: &str,
) -> Option<String> {
if let Some(cookie_hdr) = headers.get(hyper::header::COOKIE)
&& let Ok(cookie_str) = cookie_hdr.to_str()
{
let prefix = format!("{cookie_name}=");
for part in cookie_str.split(';') {
let part = part.trim();
if let Some(val) = part.strip_prefix(&prefix) {
return Some(val.to_owned());
}
}
}
if let Some(auth_hdr) = headers.get(hyper::header::AUTHORIZATION)
&& let Ok(s) = auth_hdr.to_str()
&& let Some(token) = s.strip_prefix("Bearer ")
{
return Some(token.to_owned());
}
None
}
/// Derives the key ID for `key` as a JWK thumbprint.
///
/// The input is the canonical JSON `{"crv","kty","x","y"}`, with members
/// in lexicographic order and no whitespace. It is hashed with SHA-256 and
/// the digest is encoded as base64url without padding.
fn compute_kid(key: &VerifyingKey) -> String {
    let point = key.to_encoded_point(false);
    let (x, y) = (
        B64.encode(point.x().expect("uncompressed point has x")),
        B64.encode(point.y().expect("uncompressed point has y")),
    );
    let canonical_jwk =
        format!(r#"{{"crv":"P-256","kty":"EC","x":"{x}","y":"{y}"}}"#);
    let mut hasher = Sha256::new();
    hasher.update(canonical_jwk.as_bytes());
    B64.encode(hasher.finalize())
}
/// Current wall-clock time as whole seconds since the Unix epoch, or `None`
/// if the system clock reads earlier than the epoch.
fn now_secs() -> Option<u64> {
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => Some(elapsed.as_secs()),
        Err(_) => None,
    }
}
/// Writes `data` to `path`, replacing any existing file. On Unix the file
/// is created with mode `0o600` (owner read/write only).
///
/// The data is first written to a sibling `<path>.tmp`, fsynced, and then
/// renamed over `path`. A crash or power loss part-way through therefore
/// never leaves a truncated file at `path`.
///
/// # Errors
/// Returns any I/O error from removing a stale temp file, or from creating,
/// writing, syncing or renaming the temp file. On failure the temp file is
/// removed on a best-effort basis.
fn write_private_file(path: &Path, data: &[u8]) -> std::io::Result<()> {
    use std::io::Write;

    let mut tmp_name = path.as_os_str().to_owned();
    tmp_name.push(".tmp");
    let tmp_path = std::path::PathBuf::from(tmp_name);

    // A temp file left by an earlier crash may carry looser permissions, and
    // `mode` below only applies to newly created files, so remove it first.
    match std::fs::remove_file(&tmp_path) {
        Ok(()) => {}
        Err(e) if e.kind() == std::io::ErrorKind::NotFound => {}
        Err(e) => return Err(e),
    }

    let mut options = std::fs::OpenOptions::new();
    options.write(true).create_new(true);
    #[cfg(unix)]
    {
        use std::os::unix::fs::OpenOptionsExt;
        options.mode(0o600);
    }

    let result = options
        .open(&tmp_path)
        .and_then(|mut f| {
            f.write_all(data)?;
            // Make the bytes durable before the rename publishes them.
            f.sync_all()
        })
        .and_then(|()| std::fs::rename(&tmp_path, path));
    if result.is_err() {
        // Best effort: don't leave a partial temp file behind.
        let _ = std::fs::remove_file(&tmp_path);
    }
    result
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::auth::AnonymousAuthenticator;
    use tempfile::TempDir;

    /// Manager with a 300-second token lifetime and no inner authenticator.
    fn test_manager(tmp: &TempDir) -> JwtManager {
        manager_with_validity(tmp, 300)
    }

    /// Manager whose key lives under `tmp`, using cookie `sess` and issuing
    /// tokens valid for `validity_secs`.
    fn manager_with_validity(tmp: &TempDir, validity_secs: u64) -> JwtManager {
        JwtManager::load_or_generate(
            tmp.path(),
            JwtConfig {
                cookie_name: "sess".to_owned(),
                validity_secs,
            },
            None,
        )
        .expect("manager creation")
    }

    /// Identity `user` belonging to the single group `admins`.
    fn identity(user: &str) -> Identity {
        Identity {
            username: user.to_owned(),
            groups: vec!["admins".to_owned()],
        }
    }

    /// Headers carrying `token` in the `sess` cookie.
    fn cookie_headers(token: &str) -> hyper::HeaderMap {
        let mut hdrs = hyper::HeaderMap::new();
        hdrs.insert(
            hyper::header::COOKIE,
            format!("sess={token}").parse().unwrap(),
        );
        hdrs
    }

    /// Unwraps `Some(JwtResult::Valid(_))`, panicking with `what` otherwise.
    fn expect_valid(result: Option<JwtResult>, what: &str) -> Identity {
        match result.expect(what) {
            JwtResult::Valid(id) => id,
            other @ (JwtResult::Invalid | JwtResult::Expired | JwtResult::NotMine) => {
                panic!("{what}: expected Valid, got {other:?}")
            }
        }
    }

    #[test]
    fn issue_then_validate_roundtrip() {
        let tmp = TempDir::new().unwrap();
        let mgr = test_manager(&tmp);
        let token = mgr.issue(&identity("alice")).expect("issue");
        let got = expect_valid(mgr.validate(&cookie_headers(&token)), "validate");
        assert_eq!(got.username, "alice");
        assert_eq!(got.groups, vec!["admins"]);
    }

    #[test]
    fn kid_mismatch_returns_none_from_validate() {
        let tmp_a = TempDir::new().unwrap();
        let mgr_a = test_manager(&tmp_a);
        let alien_token = mgr_a.issue(&identity("alice")).unwrap();
        let tmp_b = TempDir::new().unwrap();
        let mgr_b = test_manager(&tmp_b);
        assert!(
            mgr_b.validate(&cookie_headers(&alien_token)).is_none(),
            "kid mismatch must be folded to None",
        );
    }

    #[test]
    fn bearer_token_is_accepted() {
        let tmp = TempDir::new().unwrap();
        let mgr = test_manager(&tmp);
        let token = mgr.issue(&identity("bob")).expect("issue");
        let mut hdrs = hyper::HeaderMap::new();
        hdrs.insert(
            hyper::header::AUTHORIZATION,
            format!("Bearer {token}").parse().unwrap(),
        );
        let got = expect_valid(mgr.validate(&hdrs), "validate via bearer");
        assert_eq!(got.username, "bob");
    }

    #[test]
    fn expired_token_returns_expired() {
        let tmp = TempDir::new().unwrap();
        // Zero validity: `exp == iat`, which is already expired.
        let mgr = manager_with_validity(&tmp, 0);
        let token = mgr.issue(&identity("carol")).expect("issue");
        assert!(
            matches!(mgr.validate(&cookie_headers(&token)), Some(JwtResult::Expired)),
            "expired must be Some(Expired)"
        );
    }

    #[test]
    fn tampered_signature_returns_invalid() {
        let tmp = TempDir::new().unwrap();
        let mgr = test_manager(&tmp);
        let token = mgr.issue(&identity("dave")).expect("issue");
        let mut parts: Vec<String> = token.split('.').map(str::to_owned).collect();
        let mut sig_bytes = B64.decode(&parts[2]).expect("base64 decode sig");
        sig_bytes[0] ^= 0xff;
        parts[2] = B64.encode(&sig_bytes);
        let tampered = parts.join(".");
        assert!(
            matches!(mgr.validate(&cookie_headers(&tampered)), Some(JwtResult::Invalid)),
            "tampered must be Some(Invalid)"
        );
    }

    #[test]
    fn jwks_contains_correct_coordinates() {
        let tmp = TempDir::new().unwrap();
        let mgr = test_manager(&tmp);
        let jwks: serde_json::Value = serde_json::from_str(&mgr.jwks_json()).unwrap();
        let key = &jwks["keys"][0];
        assert_eq!(key["kty"], "EC");
        assert_eq!(key["crv"], "P-256");
        assert_eq!(key["alg"], "ES256");
        let x = B64.decode(key["x"].as_str().unwrap()).unwrap();
        let y = B64.decode(key["y"].as_str().unwrap()).unwrap();
        assert_eq!(x.len(), 32);
        assert_eq!(y.len(), 32);
        let expected_kid = compute_kid(&mgr.verifying_key);
        assert_eq!(key["kid"].as_str().unwrap(), expected_kid);
    }

    #[test]
    fn make_set_cookie_secure_flag() {
        let tmp = TempDir::new().unwrap();
        let mgr = test_manager(&tmp);
        let id = identity("eve");
        let plain = mgr.make_set_cookie(&id, false).expect("plain cookie");
        assert!(plain.contains("HttpOnly"));
        assert!(plain.contains("SameSite=Strict"));
        assert!(!plain.contains("Secure"));
        let secure = mgr.make_set_cookie(&id, true).expect("secure cookie");
        assert!(secure.contains("; Secure"));
    }

    #[test]
    fn key_persists_across_reload() {
        let tmp = TempDir::new().unwrap();
        let mgr1 = test_manager(&tmp);
        let token = mgr1.issue(&identity("frank")).expect("issue");
        let mgr2 = test_manager(&tmp);
        assert!(
            matches!(mgr2.validate(&cookie_headers(&token)), Some(JwtResult::Valid(_))),
            "reloaded key must accept prior tokens"
        );
    }

    #[test]
    fn cache_hit_returns_consistent_identity() {
        let tmp = TempDir::new().unwrap();
        let mgr = test_manager(&tmp);
        let token = mgr.issue(&identity("zara")).expect("issue");
        let hdrs = cookie_headers(&token);
        let first = expect_valid(mgr.validate(&hdrs), "first validate");
        let second = expect_valid(mgr.validate(&hdrs), "second validate");
        assert_eq!(first.username, second.username);
        assert_eq!(first.groups, second.groups);
    }

    #[test]
    fn standalone_mode_is_not_session_mode() {
        let tmp = TempDir::new().unwrap();
        let mgr = test_manager(&tmp);
        assert!(!mgr.is_session_mode());
    }

    #[test]
    fn session_mode_is_detected() {
        let tmp = TempDir::new().unwrap();
        let inner: Arc<dyn Authenticator> = Arc::new(AnonymousAuthenticator);
        let mgr = JwtManager::load_or_generate(
            tmp.path(),
            JwtConfig {
                cookie_name: "sess".to_owned(),
                validity_secs: 300,
            },
            Some(inner),
        )
        .unwrap();
        assert!(mgr.is_session_mode());
    }
}