use base64::engine::general_purpose::{STANDARD as BASE64, URL_SAFE_NO_PAD as BASE64URL};
use base64::Engine;
use jsonwebtoken::{decode, encode, Algorithm, DecodingKey, EncodingKey, Header, Validation};
use rsa::traits::PublicKeyParts;
use serde::{Deserialize, Serialize};
use std::sync::Arc;
use std::time::{SystemTime, UNIX_EPOCH};
/// Claims carried by JWTs issued by `JwtSigner`.
///
/// serde serializes the fields in declaration order; `name` and `groups` are left out of
/// the payload entirely when they are `None`.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct RiseClaims {
/// Subject; `JwtSigner` copies it from the IdP `sub` claim.
pub sub: String,
/// Email address; `JwtSigner` copies it from the IdP `email` claim.
pub email: String,
/// Display name; `JwtSigner` sets it only when `"name"` is in its `claims_to_include`
/// and the IdP supplied a string `name` claim.
#[serde(skip_serializing_if = "Option::is_none")]
pub name: Option<String>,
/// Team names of the user as looked up by `JwtSigner`; `None` when that lookup failed.
#[serde(skip_serializing_if = "Option::is_none")]
pub groups: Option<Vec<String>>,
/// Issued-at time in seconds since the Unix epoch.
pub iat: u64,
/// Expiry time; `JwtSigner` writes seconds since the Unix epoch unless overridden.
pub exp: u64,
/// Issuer; must equal the signer's configured issuer for verification to succeed.
pub iss: String,
/// Audience: the Rise public URL for user tokens, the project URL for ingress tokens.
/// Not checked by `JwtSigner::verify_jwt_skip_aud`.
pub aud: String,
}
/// Issues and verifies Rise JWTs.
///
/// User tokens are signed with an HS256 secret. Ingress tokens are signed with an RS256 key
/// whose public half is published through `generate_jwks`.
pub struct JwtSigner {
/// HMAC-SHA256 signing key built from the configured secret.
hs256_encoding_key: EncodingKey,
/// HMAC-SHA256 verification key built from the same secret.
hs256_decoding_key: DecodingKey,
/// RS256 private (signing) key.
rs256_encoding_key: Arc<EncodingKey>,
/// RS256 public (verification) key.
rs256_decoding_key: Arc<DecodingKey>,
/// PEM text of the RS256 public key; parsed again when building the JWKS.
rs256_public_key_pem: String,
/// `kid` set on RS256 token headers and advertised in the JWKS: the first 16 hex
/// characters of SHA-256 over `rs256_public_key_pem`.
rs256_key_id: String,
/// Value written to `iss` and required when verifying.
issuer: String,
/// Seconds added to the current time to form `exp` when no expiry override is given.
pub(crate) default_expiry_seconds: u64,
/// Optional IdP claims to copy into tokens; the signing methods only consult `"name"`.
claims_to_include: Vec<String>,
}
/// Errors produced while constructing a `JwtSigner` or issuing/verifying tokens.
#[derive(Debug, thiserror::Error)]
pub enum JwtSignerError {
/// The HS256 secret is not valid base64. `JwtSigner::new` also reports a secret that
/// decodes to fewer than 32 bytes through this variant (as `InvalidLength`).
#[error("Invalid base64 secret: {0}")]
InvalidBase64(#[from] base64::DecodeError),
/// Any `jsonwebtoken` error. Despite the name, this also covers header decoding,
/// signature/claim verification failures and unsupported algorithms.
#[error("JWT signing failed: {0}")]
SigningFailed(#[from] jsonwebtoken::errors::Error),
/// The system clock reported a time before the Unix epoch.
#[error("System time error: {0}")]
SystemTimeError(#[from] std::time::SystemTimeError),
/// A required IdP claim (e.g. `sub` or `email`) is absent or not a string.
#[error("Missing required claim: {0}")]
MissingClaim(String),
/// RS256 key generation or parsing failed.
#[error("RSA key generation failed: {0}")]
RsaKeyError(String),
/// Encoding or decoding an RS256 key as PEM failed.
#[error("PEM encoding failed: {0}")]
PemError(String),
}
impl JwtSigner {
pub fn new(
hs256_secret_base64: &str,
issuer: String,
default_expiry_seconds: u64,
claims_to_include: Vec<String>,
rs256_private_key_pem: Option<&str>,
rs256_public_key_pem: Option<&str>,
) -> Result<Self, JwtSignerError> {
let secret = BASE64.decode(hs256_secret_base64)?;
if secret.len() < 32 {
return Err(JwtSignerError::InvalidBase64(
base64::DecodeError::InvalidLength(secret.len()),
));
}
let hs256_encoding_key = EncodingKey::from_secret(&secret);
let hs256_decoding_key = DecodingKey::from_secret(&secret);
let (rs256_encoding_key, rs256_decoding_key, rs256_public_key_pem, rs256_key_id) = if let (
Some(private_pem),
Some(public_pem),
) =
(rs256_private_key_pem, rs256_public_key_pem)
{
tracing::info!("Using pre-configured RS256 key pair");
let encoding_key = EncodingKey::from_rsa_pem(private_pem.as_bytes()).map_err(|e| {
JwtSignerError::RsaKeyError(format!("Invalid RS256 private key: {}", e))
})?;
let decoding_key = DecodingKey::from_rsa_pem(public_pem.as_bytes()).map_err(|e| {
JwtSignerError::RsaKeyError(format!("Invalid RS256 public key: {}", e))
})?;
use sha2::{Digest, Sha256};
let mut hasher = Sha256::new();
hasher.update(public_pem.as_bytes());
let hash = hasher.finalize();
let key_id = format!("{:x}", hash)[..16].to_string();
(encoding_key, decoding_key, public_pem.to_string(), key_id)
} else if let Some(private_pem) = rs256_private_key_pem {
tracing::info!("Using pre-configured RS256 private key, deriving public key");
use rsa::pkcs8::{DecodePrivateKey, EncodePublicKey};
use rsa::RsaPrivateKey;
let private_key = RsaPrivateKey::from_pkcs8_pem(private_pem).map_err(|e| {
JwtSignerError::RsaKeyError(format!("Invalid RS256 private key PEM: {}", e))
})?;
let public_key = rsa::RsaPublicKey::from(&private_key);
let public_key_pem = public_key
.to_public_key_pem(rsa::pkcs8::LineEnding::LF)
.map_err(|e| JwtSignerError::PemError(e.to_string()))?;
let encoding_key = EncodingKey::from_rsa_pem(private_pem.as_bytes())
.map_err(|e| JwtSignerError::RsaKeyError(e.to_string()))?;
let decoding_key = DecodingKey::from_rsa_pem(public_key_pem.as_bytes())
.map_err(|e| JwtSignerError::RsaKeyError(e.to_string()))?;
use sha2::{Digest, Sha256};
let mut hasher = Sha256::new();
hasher.update(public_key_pem.as_bytes());
let hash = hasher.finalize();
let key_id = format!("{:x}", hash)[..16].to_string();
(encoding_key, decoding_key, public_key_pem, key_id)
} else {
tracing::warn!("No RS256 keys configured - generating new key pair. JWTs will be invalidated on restart. Configure rs256_private_key_pem to persist keys.");
use rsa::pkcs8::{EncodePrivateKey, EncodePublicKey};
use rsa::{RsaPrivateKey, RsaPublicKey};
let mut rng = rand::thread_rng();
let bits = 2048;
let private_key = RsaPrivateKey::new(&mut rng, bits)
.map_err(|e| JwtSignerError::RsaKeyError(e.to_string()))?;
let public_key = RsaPublicKey::from(&private_key);
let private_key_pem = private_key
.to_pkcs8_pem(rsa::pkcs8::LineEnding::LF)
.map_err(|e| JwtSignerError::PemError(e.to_string()))?
.to_string();
let public_key_pem = public_key
.to_public_key_pem(rsa::pkcs8::LineEnding::LF)
.map_err(|e| JwtSignerError::PemError(e.to_string()))?;
let encoding_key = EncodingKey::from_rsa_pem(private_key_pem.as_bytes())
.map_err(|e| JwtSignerError::RsaKeyError(e.to_string()))?;
let decoding_key = DecodingKey::from_rsa_pem(public_key_pem.as_bytes())
.map_err(|e| JwtSignerError::RsaKeyError(e.to_string()))?;
use sha2::{Digest, Sha256};
let mut hasher = Sha256::new();
hasher.update(public_key_pem.as_bytes());
let hash = hasher.finalize();
let key_id = format!("{:x}", hash)[..16].to_string();
(encoding_key, decoding_key, public_key_pem, key_id)
};
Ok(Self {
hs256_encoding_key,
hs256_decoding_key,
rs256_encoding_key: Arc::new(rs256_encoding_key),
rs256_decoding_key: Arc::new(rs256_decoding_key),
rs256_public_key_pem,
rs256_key_id,
issuer,
default_expiry_seconds,
claims_to_include,
})
}
pub fn generate_jwks(&self) -> Result<serde_json::Value, JwtSignerError> {
use rsa::pkcs8::DecodePublicKey;
use rsa::RsaPublicKey;
let public_key = RsaPublicKey::from_public_key_pem(&self.rs256_public_key_pem)
.map_err(|e| JwtSignerError::PemError(e.to_string()))?;
let n = BASE64URL.encode(public_key.n().to_bytes_be());
let e = BASE64URL.encode(public_key.e().to_bytes_be());
Ok(serde_json::json!({
"keys": [{
"kty": "RSA",
"use": "sig",
"alg": "RS256",
"kid": self.rs256_key_id,
"n": n,
"e": e,
}]
}))
}
pub async fn sign_user_jwt(
&self,
idp_claims: &serde_json::Value,
user_id: uuid::Uuid,
db_pool: &sqlx::PgPool,
rise_public_url: &str,
expiry_override: Option<u64>,
) -> Result<String, JwtSignerError> {
let now = SystemTime::now().duration_since(UNIX_EPOCH)?.as_secs();
let exp = expiry_override.unwrap_or_else(|| now + self.default_expiry_seconds);
let sub = idp_claims
.get("sub")
.and_then(|v| v.as_str())
.ok_or_else(|| JwtSignerError::MissingClaim("sub".to_string()))?
.to_string();
let email = idp_claims
.get("email")
.and_then(|v| v.as_str())
.ok_or_else(|| JwtSignerError::MissingClaim("email".to_string()))?
.to_string();
let name = if self.claims_to_include.contains(&"name".to_string()) {
idp_claims
.get("name")
.and_then(|v| v.as_str())
.map(|s| s.to_string())
} else {
None
};
let groups = crate::db::teams::get_team_names_for_user(db_pool, user_id)
.await
.ok();
let claims = RiseClaims {
sub,
email,
name,
groups,
iat: now,
exp,
iss: self.issuer.clone(),
aud: rise_public_url.to_string(),
};
let header = Header::new(Algorithm::HS256);
let token = encode(&header, &claims, &self.hs256_encoding_key)?;
Ok(token)
}
pub async fn sign_ingress_jwt(
&self,
idp_claims: &serde_json::Value,
user_id: uuid::Uuid,
db_pool: &sqlx::PgPool,
project_url: &str,
expiry_override: Option<u64>,
) -> Result<String, JwtSignerError> {
let now = SystemTime::now().duration_since(UNIX_EPOCH)?.as_secs();
let exp = expiry_override.unwrap_or_else(|| now + self.default_expiry_seconds);
let sub = idp_claims
.get("sub")
.and_then(|v| v.as_str())
.ok_or_else(|| JwtSignerError::MissingClaim("sub".to_string()))?
.to_string();
let email = idp_claims
.get("email")
.and_then(|v| v.as_str())
.ok_or_else(|| JwtSignerError::MissingClaim("email".to_string()))?
.to_string();
let name = if self.claims_to_include.contains(&"name".to_string()) {
idp_claims
.get("name")
.and_then(|v| v.as_str())
.map(|s| s.to_string())
} else {
None
};
let groups = crate::db::teams::get_team_names_for_user(db_pool, user_id)
.await
.ok();
let claims = RiseClaims {
sub,
email,
name,
groups,
iat: now,
exp,
iss: self.issuer.clone(),
aud: project_url.to_string(),
};
let mut header = Header::new(Algorithm::RS256);
header.kid = Some(self.rs256_key_id.clone());
let token = encode(&header, &claims, &self.rs256_encoding_key)?;
Ok(token)
}
pub fn verify_jwt_skip_aud(&self, token: &str) -> Result<RiseClaims, JwtSignerError> {
let header = jsonwebtoken::decode_header(token)?;
match header.alg {
Algorithm::HS256 => {
let mut validation = Validation::new(Algorithm::HS256);
validation.set_issuer(&[&self.issuer]);
validation.validate_aud = false;
let token_data =
decode::<RiseClaims>(token, &self.hs256_decoding_key, &validation)?;
Ok(token_data.claims)
}
Algorithm::RS256 => {
let mut validation = Validation::new(Algorithm::RS256);
validation.set_issuer(&[&self.issuer]);
validation.validate_aud = false;
let token_data =
decode::<RiseClaims>(token, &self.rs256_decoding_key, &validation)?;
Ok(token_data.claims)
}
_ => Err(JwtSignerError::SigningFailed(
jsonwebtoken::errors::Error::from(
jsonwebtoken::errors::ErrorKind::InvalidAlgorithm,
),
)),
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    const TEST_ISSUER: &str = "https://rise.test";

    /// Signer with an all-zero 32-byte HS256 secret and a freshly generated RS256 key pair.
    fn create_test_signer() -> JwtSigner {
        let secret = BASE64.encode([0u8; 32]);
        JwtSigner::new(
            &secret,
            TEST_ISSUER.to_string(),
            3600,
            vec!["sub".to_string(), "email".to_string(), "name".to_string()],
            None,
            None,
        )
        .unwrap()
    }

    fn now_secs() -> u64 {
        SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap()
            .as_secs()
    }

    /// Claims issued by `iss` that stay valid for one hour.
    fn test_claims(iss: &str) -> RiseClaims {
        let now = now_secs();
        RiseClaims {
            sub: "user123".to_string(),
            email: "user@example.com".to_string(),
            name: None,
            groups: None,
            iat: now,
            exp: now + 3600,
            iss: iss.to_string(),
            aud: "https://myapp.apps.rise.dev".to_string(),
        }
    }

    fn hs256_token(claims: &RiseClaims, key: &EncodingKey) -> String {
        encode(&Header::new(Algorithm::HS256), claims, key).unwrap()
    }

    #[test]
    fn test_create_signer() {
        let signer = create_test_signer();
        assert!(!signer.rs256_public_key_pem.is_empty());
        assert!(!signer.rs256_key_id.is_empty());
        assert_eq!(signer.rs256_key_id.len(), 16);
    }

    #[test]
    fn test_generate_jwks() {
        let signer = create_test_signer();
        let jwks = signer.generate_jwks().unwrap();
        let keys = jwks.get("keys").unwrap().as_array().unwrap();
        assert_eq!(keys.len(), 1);
        let key = &keys[0];
        assert_eq!(key.get("kty").unwrap().as_str().unwrap(), "RSA");
        assert_eq!(key.get("use").unwrap().as_str().unwrap(), "sig");
        assert_eq!(key.get("alg").unwrap().as_str().unwrap(), "RS256");
        assert_eq!(
            key.get("kid").unwrap().as_str().unwrap(),
            &signer.rs256_key_id
        );
        assert!(key.get("n").is_some());
        assert!(key.get("e").is_some());
    }

    #[test]
    fn test_verify_rs256_jwt() {
        let signer = create_test_signer();
        let mut claims = test_claims(TEST_ISSUER);
        claims.sub = "user456".to_string();
        claims.email = "user2@example.com".to_string();
        let mut header = Header::new(Algorithm::RS256);
        header.kid = Some(signer.rs256_key_id.to_string());
        let token = encode(&header, &claims, &signer.rs256_encoding_key).unwrap();
        let verified_claims = signer.verify_jwt_skip_aud(&token).unwrap();
        assert_eq!(verified_claims.sub, "user456");
        assert_eq!(verified_claims.email, "user2@example.com");
        assert_eq!(verified_claims.aud, "https://myapp.apps.rise.dev");
    }

    #[test]
    fn test_verify_hs256_jwt() {
        let signer = create_test_signer();
        let mut claims = test_claims(TEST_ISSUER);
        claims.name = Some("Test User".to_string());
        claims.groups = Some(vec!["team-a".to_string()]);
        let token = hs256_token(&claims, &signer.hs256_encoding_key);
        let verified = signer.verify_jwt_skip_aud(&token).unwrap();
        assert_eq!(verified.sub, "user123");
        assert_eq!(verified.name.as_deref(), Some("Test User"));
        assert_eq!(verified.groups, Some(vec!["team-a".to_string()]));
    }

    #[test]
    fn test_verify_rejects_invalid_tokens() {
        let signer = create_test_signer();

        // Issuer other than the configured one.
        let token = hs256_token(&test_claims("https://evil.test"), &signer.hs256_encoding_key);
        assert!(signer.verify_jwt_skip_aud(&token).is_err());

        // Expired well beyond the default leeway.
        let now = now_secs();
        let mut expired = test_claims(TEST_ISSUER);
        expired.iat = now - 7200;
        expired.exp = now - 3600;
        let token = hs256_token(&expired, &signer.hs256_encoding_key);
        assert!(signer.verify_jwt_skip_aud(&token).is_err());

        // Signed with a different HS256 secret.
        let other_key = EncodingKey::from_secret(&[1u8; 32]);
        let token = hs256_token(&test_claims(TEST_ISSUER), &other_key);
        assert!(signer.verify_jwt_skip_aud(&token).is_err());

        // Algorithm the signer does not accept.
        let token = encode(
            &Header::new(Algorithm::HS384),
            &test_claims(TEST_ISSUER),
            &signer.hs256_encoding_key,
        )
        .unwrap();
        assert!(signer.verify_jwt_skip_aud(&token).is_err());
    }

    #[test]
    fn test_configured_private_key_and_pair_agree() {
        use rsa::pkcs8::{EncodePrivateKey, EncodePublicKey, LineEnding};
        use rsa::{RsaPrivateKey, RsaPublicKey};

        let private_key = RsaPrivateKey::new(&mut rand::thread_rng(), 2048).unwrap();
        let private_pem = private_key.to_pkcs8_pem(LineEnding::LF).unwrap();
        let public_pem = RsaPublicKey::from(&private_key)
            .to_public_key_pem(LineEnding::LF)
            .unwrap();
        let secret = BASE64.encode([0u8; 32]);

        let derived = JwtSigner::new(
            &secret,
            TEST_ISSUER.to_string(),
            3600,
            vec![],
            Some(private_pem.as_str()),
            None,
        )
        .unwrap();
        let paired = JwtSigner::new(
            &secret,
            TEST_ISSUER.to_string(),
            3600,
            vec![],
            Some(private_pem.as_str()),
            Some(public_pem.as_str()),
        )
        .unwrap();

        // Deriving the public key reproduces the same PEM, hence the same key id.
        assert_eq!(derived.rs256_public_key_pem, public_pem);
        assert_eq!(derived.rs256_key_id, paired.rs256_key_id);

        let mut header = Header::new(Algorithm::RS256);
        header.kid = Some(derived.rs256_key_id.clone());
        let token = encode(&header, &test_claims(TEST_ISSUER), &derived.rs256_encoding_key).unwrap();
        assert_eq!(paired.verify_jwt_skip_aud(&token).unwrap().sub, "user123");
    }

    #[test]
    fn test_invalid_secret_length() {
        let short_secret = BASE64.encode(b"short");
        let result = JwtSigner::new(
            &short_secret,
            TEST_ISSUER.to_string(),
            3600,
            vec!["sub".to_string(), "email".to_string()],
            None,
            None,
        );
        assert!(matches!(result, Err(JwtSignerError::InvalidBase64(_))));
    }

    #[test]
    fn test_invalid_base64_secret() {
        let result = JwtSigner::new(
            "not base64!",
            TEST_ISSUER.to_string(),
            3600,
            vec![],
            None,
            None,
        );
        assert!(matches!(result, Err(JwtSignerError::InvalidBase64(_))));
    }
}