use crate::{DidError, DidResult};
use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine};
use ring::{
rand::SystemRandom,
rsa::{self as ring_rsa, KeyPair as RingKeyPair},
signature::{self as ring_sig, RsaPublicKeyComponents, UnparsedPublicKey},
};
use serde::{Deserialize, Serialize};
/// RSA modulus size, in bits, used by `RsaKeyPair::generate` when no explicit
/// size is requested.
pub const DEFAULT_KEY_SIZE: usize = 2_048;
/// An RSA private key, kept both as its PKCS#1 `RSAPrivateKey` DER bytes and as
/// a parsed `ring` key pair used for signing.
///
/// NOTE(review): the DER copy is a plain `Vec<u8>` and is not zeroized when the
/// value is dropped.
pub struct RsaKeyPair {
    /// The exact DER bytes this key was loaded from (returned by `to_pkcs1_der`).
    pkcs1_der: Vec<u8>,
    /// `ring`'s parsed form of `pkcs1_der`.
    ring_kp: RingKeyPair,
}

impl RsaKeyPair {
    /// Generates a fresh RSA key pair with a [`DEFAULT_KEY_SIZE`]-bit modulus.
    ///
    /// # Errors
    ///
    /// Same as [`RsaKeyPair::generate_with_bits`].
    #[cfg(feature = "keygen")]
    pub fn generate() -> DidResult<Self> {
        Self::generate_with_bits(DEFAULT_KEY_SIZE)
    }

    /// Generates a fresh RSA key pair whose modulus is `bits` bits long.
    ///
    /// The key is produced by the `rsa` crate, exported as PKCS#1 DER and then
    /// re-imported through [`RsaKeyPair::from_pkcs1_der`], so signing always
    /// goes through `ring`.
    ///
    /// # Errors
    ///
    /// * [`DidError::InvalidKey`] if `bits` is below 2048 or above 4096, or if
    ///   generation or the `ring` import fails.
    /// * [`DidError::SerializationError`] if the PKCS#1 export fails.
    #[cfg(feature = "keygen")]
    pub fn generate_with_bits(bits: usize) -> DidResult<Self> {
        use rsa::{pkcs1::EncodeRsaPrivateKey, RsaPrivateKey};

        // `ring` refuses to load RSA private keys whose modulus exceeds 4096
        // bits. Reject larger sizes before spending seconds generating a key
        // that `from_pkcs1_der` is guaranteed to reject.
        const MAX_BITS: usize = 4096;

        if bits < 2048 {
            return Err(DidError::InvalidKey(
                "RSA key size must be at least 2048 bits for security".to_string(),
            ));
        }
        if bits > MAX_BITS {
            return Err(DidError::InvalidKey(format!(
                "RSA key size must be at most {MAX_BITS} bits"
            )));
        }

        let mut rng = rsa::rand_core::OsRng;
        let private_key = RsaPrivateKey::new(&mut rng, bits)
            .map_err(|e| DidError::InvalidKey(format!("RSA key generation failed: {e}")))?;
        let pkcs1_der = private_key
            .to_pkcs1_der()
            .map_err(|e| DidError::SerializationError(format!("PKCS#1 DER export failed: {e}")))?
            .as_bytes()
            .to_vec();
        Self::from_pkcs1_der(&pkcs1_der)
    }

    /// Loads a key pair from PKCS#1 `RSAPrivateKey` DER bytes.
    ///
    /// # Errors
    ///
    /// [`DidError::InvalidKey`] if `ring` rejects the DER.
    pub fn from_pkcs1_der(der: &[u8]) -> DidResult<Self> {
        let ring_kp = RingKeyPair::from_der(der)
            .map_err(|e| DidError::InvalidKey(format!("Invalid PKCS#1 DER private key: {e}")))?;
        Ok(Self {
            pkcs1_der: der.to_vec(),
            ring_kp,
        })
    }

    /// Returns a copy of the PKCS#1 DER bytes this key was loaded from.
    ///
    /// # Errors
    ///
    /// Currently never fails; the `Result` is part of the public signature.
    pub fn to_pkcs1_der(&self) -> DidResult<Vec<u8>> {
        Ok(self.pkcs1_der.clone())
    }

    /// Returns the public half of the key as an RS256 signing JWK.
    ///
    /// The JWK has `kty`, `alg`, `use`, `n` and `e` members; `n` and `e` are
    /// unpadded base64url.
    ///
    /// # Errors
    ///
    /// Currently never fails; the `Result` is part of the public signature.
    pub fn public_key_jwk(&self) -> DidResult<serde_json::Value> {
        let components: RsaPublicKeyComponents<Vec<u8>> =
            RsaPublicKeyComponents::from(self.ring_kp.public());
        // RFC 7518 §6.3.1 requires `n` and `e` to use the minimum number of
        // octets, so never emit leading zero bytes.
        Ok(serde_json::json!({
            "kty": "RSA",
            "alg": "RS256",
            "use": "sig",
            "n": URL_SAFE_NO_PAD.encode(strip_leading_zeros(&components.n)),
            "e": URL_SAFE_NO_PAD.encode(strip_leading_zeros(&components.e))
        }))
    }

    /// Returns the public key encoded as a PKCS#1 `RSAPublicKey` DER structure.
    ///
    /// # Errors
    ///
    /// Propagates any error from the DER encoder.
    pub fn public_key_pkcs1_der(&self) -> DidResult<Vec<u8>> {
        let components: RsaPublicKeyComponents<Vec<u8>> =
            RsaPublicKeyComponents::from(self.ring_kp.public());
        encode_pkcs1_public_key_der(&components.n, &components.e)
    }
}
/// Encodes an RSA public key as a PKCS#1 `RSAPublicKey` DER structure:
/// `SEQUENCE { modulus INTEGER, publicExponent INTEGER }`.
///
/// `n` and `e` are unsigned big-endian magnitudes; each is turned into a DER
/// INTEGER by [`encode_der_integer`].
fn encode_pkcs1_public_key_der(n: &[u8], e: &[u8]) -> DidResult<Vec<u8>> {
    let body: Vec<u8> = [encode_der_integer(n), encode_der_integer(e)].concat();

    let mut der = Vec::with_capacity(6 + body.len());
    der.push(0x30); // SEQUENCE tag
    encode_der_length(&mut der, body.len());
    der.extend_from_slice(&body);
    Ok(der)
}
/// Encodes an unsigned big-endian magnitude as a DER `INTEGER` (tag `0x02`).
///
/// Redundant leading zero bytes are removed first. If the most significant
/// remaining byte has its top bit set, a `0x00` pad byte is prepended so the
/// value is not read as negative (DER INTEGERs are two's complement).
///
/// An empty input is encoded as the value zero (`02 01 00`), because DER does
/// not allow an INTEGER with zero content octets.
fn encode_der_integer(bytes: &[u8]) -> Vec<u8> {
    const ZERO: &[u8] = &[0x00];

    let stripped = strip_leading_zeros(bytes);
    let magnitude = if stripped.is_empty() { ZERO } else { stripped };
    let needs_pad = magnitude.first().is_some_and(|&b| b & 0x80 != 0);
    let value_len = magnitude.len() + usize::from(needs_pad);

    let mut out = Vec::with_capacity(2 + value_len);
    out.push(0x02); // INTEGER tag
    encode_der_length(&mut out, value_len);
    if needs_pad {
        out.push(0x00);
    }
    out.extend_from_slice(magnitude);
    out
}
/// Drops leading `0x00` bytes from a big-endian magnitude.
///
/// A non-empty all-zero input keeps its final byte, so it still represents
/// zero. An empty input is returned unchanged (empty).
fn strip_leading_zeros(bytes: &[u8]) -> &[u8] {
    match bytes.iter().position(|&byte| byte != 0) {
        Some(first_significant) => &bytes[first_significant..],
        // Only zeros (or nothing at all): keep at most the last byte.
        None => &bytes[bytes.len().saturating_sub(1)..],
    }
}
/// Appends the DER encoding of a content length to `out`.
///
/// Lengths below `0x80` use the one-byte short form. Larger lengths use the
/// long form: `0x80 | k`, followed by the `k` significant big-endian bytes of
/// the length.
///
/// Every `usize` is representable, so this never panics. That matters because
/// the lengths can be derived from untrusted input, such as JWK members passed
/// to `Rs256Verifier::from_jwk`.
fn encode_der_length(out: &mut Vec<u8>, len: usize) {
    if len < 0x80 {
        out.push(len as u8); // fits: checked < 0x80 above
        return;
    }

    let be_bytes = len.to_be_bytes();
    let leading_zeros = be_bytes.iter().take_while(|&&b| b == 0).count();
    let significant = &be_bytes[leading_zeros..];
    // `significant.len()` is at most `size_of::<usize>()` (<= 16), so it fits in
    // the 7-bit count field.
    out.push(0x80 | significant.len() as u8);
    out.extend_from_slice(significant);
}
/// Produces RS256 (RSASSA-PKCS1-v1_5 with SHA-256) signatures and compact JWS
/// tokens using an [`RsaKeyPair`].
pub struct Rs256Signer {
    /// Optional `kid` placed in JWS headers produced by `sign_jws`.
    key_id: Option<String>,
    /// The private key used for every signature.
    key_pair: RsaKeyPair,
}

impl Rs256Signer {
    /// Wraps `key_pair`, optionally tagging produced JWS headers with `key_id`.
    pub fn new(key_pair: RsaKeyPair, key_id: Option<&str>) -> Self {
        let key_id = key_id.map(str::to_owned);
        Self { key_pair, key_id }
    }

    /// Builds a signer directly from PKCS#1 `RSAPrivateKey` DER bytes.
    ///
    /// # Errors
    ///
    /// Whatever [`RsaKeyPair::from_pkcs1_der`] returns for invalid DER.
    pub fn from_pkcs1_der(der: &[u8], key_id: Option<&str>) -> DidResult<Self> {
        RsaKeyPair::from_pkcs1_der(der).map(|key_pair| Self::new(key_pair, key_id))
    }

    /// The key identifier given at construction, if any.
    pub fn key_id(&self) -> Option<&str> {
        self.key_id.as_deref()
    }

    /// Signs `message` with RSASSA-PKCS1-v1_5 / SHA-256.
    ///
    /// The returned signature is as long as the key's modulus.
    ///
    /// # Errors
    ///
    /// [`DidError::SigningFailed`] if `ring` reports a signing failure.
    pub fn sign(&self, message: &[u8]) -> DidResult<Vec<u8>> {
        let ring_kp = &self.key_pair.ring_kp;
        let mut sig = vec![0u8; ring_kp.public().modulus_len()];
        ring_kp
            .sign(&ring_sig::RSA_PKCS1_SHA256, &SystemRandom::new(), message, &mut sig)
            .map_err(|e| DidError::SigningFailed(format!("RS256 signing failed: {e}")))?;
        Ok(sig)
    }

    /// Produces a compact JWS `header.payload.signature` over `payload`.
    ///
    /// The protected header is `{"alg":"RS256"}`, plus `kid` when one was
    /// configured. All three segments are unpadded base64url.
    ///
    /// # Errors
    ///
    /// * [`DidError::SerializationError`] if the header cannot be serialized.
    /// * Any error from [`Rs256Signer::sign`].
    pub fn sign_jws(&self, payload: &[u8]) -> DidResult<String> {
        let header_json = serde_json::to_string(&Rs256JwsHeader {
            alg: "RS256".to_string(),
            kid: self.key_id.clone(),
        })
        .map_err(|e| DidError::SerializationError(e.to_string()))?;

        // Build the signing input in place, then append the signature segment.
        let mut jws = format!(
            "{}.{}",
            URL_SAFE_NO_PAD.encode(header_json.as_bytes()),
            URL_SAFE_NO_PAD.encode(payload)
        );
        let sig = self.sign(jws.as_bytes())?;
        jws.push('.');
        jws.push_str(&URL_SAFE_NO_PAD.encode(&sig));
        Ok(jws)
    }
}
pub struct Rs256Verifier {
public_key_der: Vec<u8>,
}
impl Rs256Verifier {
pub fn from_jwk(jwk: &serde_json::Value) -> DidResult<Self> {
let kty = jwk["kty"].as_str().unwrap_or("");
if kty != "RSA" {
return Err(DidError::InvalidKey(format!(
"Expected RSA JWK, got kty={}",
kty
)));
}
let n_b64 = jwk["n"]
.as_str()
.ok_or_else(|| DidError::InvalidKey("Missing 'n' in RSA JWK".to_string()))?;
let e_b64 = jwk["e"]
.as_str()
.ok_or_else(|| DidError::InvalidKey("Missing 'e' in RSA JWK".to_string()))?;
let n_bytes = URL_SAFE_NO_PAD
.decode(n_b64)
.map_err(|e| DidError::InvalidKey(format!("Invalid 'n': {e}")))?;
let e_bytes = URL_SAFE_NO_PAD
.decode(e_b64)
.map_err(|e| DidError::InvalidKey(format!("Invalid 'e': {e}")))?;
let public_key_der = encode_pkcs1_public_key_der(&n_bytes, &e_bytes)?;
Ok(Self { public_key_der })
}
pub fn from_pkcs1_der(der: &[u8]) -> DidResult<Self> {
let pk = UnparsedPublicKey::new(&ring_sig::RSA_PKCS1_2048_8192_SHA256, der);
let _ = pk; Ok(Self {
public_key_der: der.to_vec(),
})
}
pub fn verify(&self, message: &[u8], signature_bytes: &[u8]) -> DidResult<bool> {
let pk = UnparsedPublicKey::new(
&ring_sig::RSA_PKCS1_2048_8192_SHA256,
self.public_key_der.as_slice(),
);
match pk.verify(message, signature_bytes) {
Ok(()) => Ok(true),
Err(_) => Ok(false),
}
}
pub fn verify_jws(&self, jws: &str) -> DidResult<bool> {
let parts: Vec<&str> = jws.split('.').collect();
if parts.len() != 3 {
return Err(DidError::InvalidProof("JWS must have 3 parts".to_string()));
}
let signing_input = format!("{}.{}", parts[0], parts[1]).into_bytes();
let sig_bytes = URL_SAFE_NO_PAD
.decode(parts[2])
.map_err(|e| DidError::InvalidProof(format!("Signature decode error: {e}")))?;
self.verify(&signing_input, &sig_bytes)
}
}
/// Protected header of the compact JWS tokens built by `Rs256Signer::sign_jws`.
///
/// Field order matters: derived `Serialize` writes fields in declaration order,
/// so the header JSON (and therefore the JWS signing input) starts with `alg`.
#[derive(Serialize, Deserialize)]
struct Rs256JwsHeader {
    /// JWS `alg` parameter; `sign_jws` always sets it to `"RS256"`.
    alg: String,
    /// Optional JWS `kid` parameter; left out of the JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    kid: Option<String>,
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Generates a 2048-bit key pair for tests (slow; keygen feature only).
    #[cfg(feature = "keygen")]
    fn generate_test_keypair() -> RsaKeyPair {
        RsaKeyPair::generate_with_bits(2048).expect("keygen failed")
    }

    #[test]
    #[cfg(feature = "keygen")]
    fn test_generate_rsa_keypair_2048() {
        let kp = generate_test_keypair();
        let der = kp.to_pkcs1_der().expect("export failed");
        assert!(!der.is_empty());
    }

    #[test]
    #[cfg(feature = "keygen")]
    fn test_rsa_key_too_small() {
        assert!(RsaKeyPair::generate_with_bits(1024).is_err());
    }

    #[test]
    #[cfg(feature = "keygen")]
    fn test_rsa_public_key_jwk() {
        let kp = generate_test_keypair();
        let jwk = kp.public_key_jwk().expect("jwk failed");
        assert_eq!(jwk["kty"], "RSA");
        assert_eq!(jwk["alg"], "RS256");
        assert!(jwk["n"].is_string());
        assert!(jwk["e"].is_string());
    }

    #[test]
    #[cfg(feature = "keygen")]
    fn test_rs256_sign_verify() {
        let kp = generate_test_keypair();
        let jwk = kp.public_key_jwk().expect("jwk failed");
        let signer = Rs256Signer::new(kp, Some("test-key"));
        let message = b"Hello, RS256!";
        let signature = signer.sign(message).expect("sign failed");
        assert!(!signature.is_empty());
        let verifier = Rs256Verifier::from_jwk(&jwk).expect("verifier failed");
        let valid = verifier.verify(message, &signature).expect("verify failed");
        assert!(valid);
    }

    #[test]
    #[cfg(feature = "keygen")]
    fn test_rs256_sign_verify_wrong_message() {
        let kp = generate_test_keypair();
        let jwk = kp.public_key_jwk().expect("jwk failed");
        let signer = Rs256Signer::new(kp, None);
        let signature = signer.sign(b"original").expect("sign failed");
        let verifier = Rs256Verifier::from_jwk(&jwk).expect("verifier failed");
        let valid = verifier
            .verify(b"tampered", &signature)
            .expect("verify failed");
        assert!(!valid);
    }

    #[test]
    #[cfg(feature = "keygen")]
    #[ignore = "RSA double-keygen is too slow under CI load"]
    fn test_rs256_sign_verify_wrong_key() {
        let kp1 = generate_test_keypair();
        let kp2 = generate_test_keypair();
        let jwk2 = kp2.public_key_jwk().expect("jwk failed");
        let signer = Rs256Signer::new(kp1, None);
        let signature = signer.sign(b"test").expect("sign failed");
        let verifier = Rs256Verifier::from_jwk(&jwk2).expect("verifier failed");
        let valid = verifier.verify(b"test", &signature).expect("verify failed");
        assert!(!valid);
    }

    #[test]
    #[cfg(feature = "keygen")]
    fn test_rs256_jws_sign_verify() {
        let kp = generate_test_keypair();
        let jwk = kp.public_key_jwk().expect("jwk failed");
        let signer = Rs256Signer::new(kp, Some("key-1"));
        let payload = b"jwt-payload";
        let jws = signer.sign_jws(payload).expect("sign_jws failed");
        assert_eq!(jws.split('.').count(), 3);
        let verifier = Rs256Verifier::from_jwk(&jwk).expect("verifier failed");
        let valid = verifier.verify_jws(&jws).expect("verify_jws failed");
        assert!(valid);
    }

    #[test]
    #[cfg(feature = "keygen")]
    fn test_rs256_from_pkcs1_der_roundtrip() {
        let kp = generate_test_keypair();
        let der = kp.to_pkcs1_der().expect("export failed");
        let kp2 = RsaKeyPair::from_pkcs1_der(&der).expect("import failed");
        let jwk1 = kp.public_key_jwk().expect("jwk1 failed");
        let jwk2 = kp2.public_key_jwk().expect("jwk2 failed");
        assert_eq!(jwk1["n"], jwk2["n"]);
        assert_eq!(jwk1["e"], jwk2["e"]);
    }

    #[test]
    #[cfg(feature = "keygen")]
    fn test_rsa_pkcs1_public_key_der() {
        let kp = generate_test_keypair();
        let pub_der = kp.public_key_pkcs1_der().expect("pub_der failed");
        assert!(!pub_der.is_empty());
        let verifier = Rs256Verifier::from_pkcs1_der(&pub_der).expect("verifier failed");
        let signer = Rs256Signer::new(kp, None);
        let sig = signer.sign(b"test").expect("sign failed");
        let valid = verifier.verify(b"test", &sig).expect("verify failed");
        assert!(valid);
    }

    #[test]
    fn test_rs256_from_jwk_invalid_kty() {
        let jwk = serde_json::json!({ "kty": "EC", "crv": "P-256" });
        assert!(Rs256Verifier::from_jwk(&jwk).is_err());
    }

    #[test]
    fn test_rs256_from_jwk_missing_params() {
        let jwk = serde_json::json!({ "kty": "RSA", "e": "AQAB" });
        assert!(Rs256Verifier::from_jwk(&jwk).is_err());
        let jwk = serde_json::json!({ "kty": "RSA", "n": "dGVzdA" });
        assert!(Rs256Verifier::from_jwk(&jwk).is_err());
    }

    #[test]
    fn test_verify_jws_wrong_part_count() {
        let jwk = serde_json::json!({ "kty": "RSA", "n": "dGVzdA", "e": "AQAB" });
        let verifier = Rs256Verifier::from_jwk(&jwk).expect("verifier failed");
        assert!(verifier.verify_jws("only.two").is_err());
        assert!(verifier.verify_jws("a.b.c.d").is_err());
    }

    #[test]
    fn test_encode_der_integer_zero() {
        let enc = encode_der_integer(&[0x00]);
        assert_eq!(enc, vec![0x02, 0x01, 0x00]);
    }

    #[test]
    fn test_encode_der_integer_high_bit() {
        let enc = encode_der_integer(&[0xFF]);
        assert_eq!(enc, vec![0x02, 0x02, 0x00, 0xFF]);
    }

    #[test]
    fn test_encode_der_integer_no_high_bit() {
        let enc = encode_der_integer(&[0x7F]);
        assert_eq!(enc, vec![0x02, 0x01, 0x7F]);
    }

    #[test]
    fn test_encode_der_length_short_and_long_forms() {
        let encode = |len: usize| {
            let mut out = Vec::new();
            encode_der_length(&mut out, len);
            out
        };
        assert_eq!(encode(0x7F), vec![0x7F]);
        assert_eq!(encode(0x80), vec![0x81, 0x80]);
        assert_eq!(encode(0xFF), vec![0x81, 0xFF]);
        assert_eq!(encode(0x100), vec![0x82, 0x01, 0x00]);
        assert_eq!(encode(0xFFFF), vec![0x82, 0xFF, 0xFF]);
    }

    #[test]
    fn test_encode_pkcs1_public_key_der_small() {
        let der = encode_pkcs1_public_key_der(&[0x01], &[0x01, 0x00, 0x01])
            .expect("encode failed");
        assert_eq!(
            der,
            vec![0x30, 0x08, 0x02, 0x01, 0x01, 0x02, 0x03, 0x01, 0x00, 0x01]
        );
    }

    #[test]
    fn test_strip_leading_zeros() {
        assert_eq!(strip_leading_zeros(&[0x00, 0x00, 0x01]), &[0x01]);
        assert_eq!(strip_leading_zeros(&[0x00]), &[0x00]);
        assert_eq!(strip_leading_zeros(&[0x01, 0x02]), &[0x01, 0x02]);
    }
}