use std::sync::Arc;
use std::time::{Duration, SystemTime, UNIX_EPOCH};
use base64::engine::general_purpose::URL_SAFE_NO_PAD as B64;
use base64::Engine;
use p256::ecdsa::SigningKey as EcdsaSigningKey;
use p256::elliptic_curve::sec1::ToEncodedPoint;
use p256::pkcs8::{EncodePrivateKey, LineEnding};
use parking_lot::RwLock;
use serde::{Deserialize, Serialize};
use thiserror::Error;
#[derive(Debug, Error)]
pub enum JwksError {
#[error("key generation: {0}")]
KeyGen(String),
}
/// An ES256 private signing key together with its encodings and its public JWK.
///
/// `Debug` is implemented by hand rather than derived: the derived impl would print
/// `private_pem` / `private_der` verbatim, leaking the private key into any log line
/// or panic message that formats this value with `{:?}`.
#[derive(Clone)]
pub struct SigningKey {
    /// Key identifier; this module's constructors copy it into `public_jwk.kid`.
    pub kid: String,
    /// JWS algorithm name; this module's constructors set it to `"ES256"`.
    pub alg: String,
    /// PKCS#8 PEM encoding of the private key. Secret.
    pub private_pem: String,
    /// PKCS#8 DER encoding of the private key. Secret.
    pub private_der: Vec<u8>,
    /// Public half of the key, suitable for publishing in a JWK Set.
    pub public_jwk: PublicJwk,
    /// Unix timestamp (seconds) at which this value was constructed.
    pub created_at: u64,
}

impl std::fmt::Debug for SigningKey {
    /// Formats every field except the private key material, which is redacted.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("SigningKey")
            .field("kid", &self.kid)
            .field("alg", &self.alg)
            .field("private_pem", &"<redacted>")
            .field("private_der", &"<redacted>")
            .field("public_jwk", &self.public_jwk)
            .field("created_at", &self.created_at)
            .finish()
    }
}
/// Public half of an elliptic-curve signing key in JSON Web Key form.
///
/// Fields are declared in the order they are serialized to JSON.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct PublicJwk {
    /// Key type (`"EC"` for keys built in this module).
    pub kty: String,
    /// Curve name (`"P-256"` for keys built in this module).
    pub crv: String,
    /// Affine x coordinate, base64url without padding.
    pub x: String,
    /// Affine y coordinate, base64url without padding.
    pub y: String,
    /// Key identifier used by verifiers to select this key.
    pub kid: String,
    /// JWS algorithm name.
    pub alg: String,
    /// Intended key use. Serialized as the JSON member `"use"`, which is a Rust keyword.
    #[serde(rename = "use")]
    pub use_: String,
}
impl SigningKey {
pub fn generate_es256() -> Result<Self, JwksError> {
use rand::rngs::OsRng;
let sk = EcdsaSigningKey::random(&mut OsRng);
let pk = p256::PublicKey::from(sk.verifying_key());
let point = pk.to_encoded_point(false); let x = point
.x()
.ok_or_else(|| JwksError::KeyGen("EC point missing x coordinate".into()))?;
let y = point
.y()
.ok_or_else(|| JwksError::KeyGen("EC point missing y coordinate".into()))?;
let kid = format!("es256-{}", uuid::Uuid::new_v4());
let public_jwk = PublicJwk {
kty: "EC".into(),
crv: "P-256".into(),
x: B64.encode(x.as_ref() as &[u8]),
y: B64.encode(y.as_ref() as &[u8]),
kid: kid.clone(),
alg: "ES256".into(),
use_: "sig".into(),
};
let private_pem = sk
.to_pkcs8_pem(LineEnding::LF)
.map_err(|e| JwksError::KeyGen(format!("PKCS8 PEM encode: {e}")))?
.to_string();
let private_der = sk
.to_pkcs8_der()
.map_err(|e| JwksError::KeyGen(format!("PKCS8 DER encode: {e}")))?
.as_bytes()
.to_vec();
let created_at = SystemTime::now()
.duration_since(UNIX_EPOCH)
.map(|d| d.as_secs())
.unwrap_or(0);
Ok(Self {
kid,
alg: "ES256".into(),
private_pem,
private_der,
public_jwk,
created_at,
})
}
pub fn from_pem(kid: impl Into<String>, pem: &str) -> Result<Self, JwksError> {
use p256::pkcs8::DecodePrivateKey;
let sk = EcdsaSigningKey::from_pkcs8_pem(pem)
.map_err(|e| JwksError::KeyGen(format!("PKCS8 PEM decode: {e}")))?;
let pk = p256::PublicKey::from(sk.verifying_key());
let point = pk.to_encoded_point(false);
let x = point
.x()
.ok_or_else(|| JwksError::KeyGen("EC point missing x coordinate".into()))?;
let y = point
.y()
.ok_or_else(|| JwksError::KeyGen("EC point missing y coordinate".into()))?;
let kid_s: String = kid.into();
let public_jwk = PublicJwk {
kty: "EC".into(),
crv: "P-256".into(),
x: B64.encode(x.as_ref() as &[u8]),
y: B64.encode(y.as_ref() as &[u8]),
kid: kid_s.clone(),
alg: "ES256".into(),
use_: "sig".into(),
};
let private_der = sk
.to_pkcs8_der()
.map_err(|e| JwksError::KeyGen(format!("PKCS8 DER encode: {e}")))?
.as_bytes()
.to_vec();
let created_at = SystemTime::now()
.duration_since(UNIX_EPOCH)
.map(|d| d.as_secs())
.unwrap_or(0);
Ok(Self {
kid: kid_s,
alg: "ES256".into(),
private_pem: pem.to_string(),
private_der,
public_jwk,
created_at,
})
}
}
/// A JWK Set document, serialized as `{"keys": [...]}`.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct JwksDocument {
    /// The published public keys.
    pub keys: Vec<PublicJwk>,
}
/// Cloneable handle to a rotating set of signing keys: one active key plus retired keys.
///
/// Clones share the same key state through an `Arc<RwLock<..>>`. The retention window,
/// however, is a plain field, so each handle carries its own copy.
#[derive(Clone)]
pub struct Jwks {
    /// Active and retired keys, shared by all clones of this handle.
    inner: Arc<RwLock<JwksInner>>,
    /// How long a retired key remains published before pruning removes it.
    retention: Duration,
}
/// Lock-protected key state behind a `Jwks` handle.
struct JwksInner {
    /// Key currently used for signing.
    active: SigningKey,
    /// Previously active keys, each paired with the time it was retired.
    retired: Vec<(SigningKey, SystemTime)>,
}
impl Jwks {
pub fn generate_es256() -> Result<Self, JwksError> {
let active = SigningKey::generate_es256()?;
Ok(Self::from_active(active))
}
pub fn from_active(active: SigningKey) -> Self {
Self {
inner: Arc::new(RwLock::new(JwksInner {
active,
retired: Vec::new(),
})),
retention: Duration::from_secs(7 * 24 * 3600), }
}
pub fn with_retention(mut self, retention: Duration) -> Self {
self.retention = retention;
self
}
pub fn active_key(&self) -> SigningKey {
self.inner.read().active.clone()
}
pub fn rotate(&self) -> Result<SigningKey, JwksError> {
let new = SigningKey::generate_es256()?;
let mut inner = self.inner.write();
let old = std::mem::replace(&mut inner.active, new.clone());
inner.retired.push((old, SystemTime::now()));
Ok(new)
}
pub fn insert_signing_key(&self, key: SigningKey) {
let mut inner = self.inner.write();
let old = std::mem::replace(&mut inner.active, key);
inner.retired.push((old, SystemTime::now()));
}
pub fn prune_expired(&self) {
let mut inner = self.inner.write();
let retention = self.retention;
inner.retired.retain(|(_, ts)| {
ts.elapsed().unwrap_or(Duration::ZERO) < retention
});
}
pub fn public_document(&self) -> JwksDocument {
let inner = self.inner.read();
let mut keys = Vec::with_capacity(1 + inner.retired.len());
keys.push(inner.active.public_jwk.clone());
for (k, _) in &inner.retired {
keys.push(k.public_jwk.clone());
}
JwksDocument { keys }
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn generated_jwks_publishes_signing_key() {
        let published = Jwks::generate_es256().unwrap().public_document();
        assert_eq!(published.keys.len(), 1);
        let jwk = &published.keys[0];
        assert_eq!(
            (
                jwk.kty.as_str(),
                jwk.crv.as_str(),
                jwk.alg.as_str(),
                jwk.use_.as_str()
            ),
            ("EC", "P-256", "ES256", "sig")
        );
        assert!(!jwk.x.is_empty());
        assert!(!jwk.y.is_empty());
        assert!(jwk.kid.starts_with("es256-"));
    }

    #[test]
    fn rotate_retains_old_key_for_verification_window() {
        let jwks = Jwks::generate_es256().unwrap();
        let before = jwks.active_key().kid;
        jwks.rotate().unwrap();
        let published = jwks.public_document();
        assert_eq!(published.keys.len(), 2);
        assert_ne!(jwks.active_key().kid, before);
        assert!(published.keys.iter().any(|jwk| jwk.kid == before));
    }

    #[test]
    fn prune_expired_drops_retired_keys_past_retention() {
        let jwks = Jwks::generate_es256()
            .unwrap()
            .with_retention(Duration::from_millis(1));
        jwks.rotate().unwrap();
        // Sleep well past the 1 ms retention so the retired key is certainly expired.
        std::thread::sleep(Duration::from_millis(20));
        jwks.prune_expired();
        assert_eq!(jwks.public_document().keys.len(), 1);
    }

    #[test]
    fn key_round_trips_through_pem() {
        let original = SigningKey::generate_es256().unwrap();
        let reloaded = SigningKey::from_pem(&original.kid, &original.private_pem).unwrap();
        assert_eq!(
            (&original.public_jwk.x, &original.public_jwk.y),
            (&reloaded.public_jwk.x, &reloaded.public_jwk.y)
        );
    }
}