use crate::signing::{KeyPair, PublicKey, SecretKey};
use serde::{Deserialize, Serialize};
use thiserror::Error;
/// Errors produced while converting keys to or from the JWK and DER formats
/// handled in this file.
///
/// Variants that carry a `String` hold a human-readable description of the
/// underlying problem (or, for `UnsupportedKeyType` and `MissingField`, the
/// offending value / field name).
#[derive(Debug, Error, Clone, PartialEq, Eq)]
pub enum KeyFormatError {
/// The DER input could not be interpreted as the expected key structure.
#[error("Invalid DER encoding: {0}")]
InvalidDer(String),
/// A JWK field could not be decoded, or the decoded key material was rejected.
#[error("Invalid JWK format: {0}")]
InvalidJwk(String),
/// The JWK's `kty` or `crv` value is not one this file supports; the payload
/// is the rejected value.
#[error("Unsupported key type: {0}")]
UnsupportedKeyType(String),
/// Decoded key material has the wrong number of bytes.
#[error("Invalid key length: expected {expected}, got {actual}")]
InvalidKeyLength { expected: usize, actual: usize },
/// JSON serialization or deserialization of a JWK failed.
#[error("Serialization error: {0}")]
SerializationError(String),
/// A JWK field required for the requested conversion is absent; the payload
/// is the field name.
#[error("Missing required field: {0}")]
MissingField(String),
}
/// Convenience alias for results whose error type is [`KeyFormatError`].
pub type KeyFormatResult<T> = Result<T, KeyFormatError>;
/// A JSON Web Key restricted to the members this file reads and writes.
///
/// Every optional member is omitted from the serialized JSON when it is
/// `None`.
///
/// `Debug` is implemented by hand rather than derived so that the private key
/// member `d` is never written to logs, panic messages or assertion output.
#[derive(Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct JwkKey {
    /// Key type; the Ed25519 conversions in this file expect `"OKP"`.
    pub kty: String,
    /// Curve name; the Ed25519 conversions in this file expect `"Ed25519"`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub crv: Option<String>,
    /// Public key bytes, base64url-encoded without padding.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub x: Option<String>,
    /// Secret key bytes, base64url-encoded without padding. Present only for
    /// private keys. Treat as sensitive.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub d: Option<String>,
    /// Intended use of the key, serialized under the JSON name `use`.
    #[serde(rename = "use", skip_serializing_if = "Option::is_none")]
    pub key_use: Option<String>,
    /// Optional key identifier.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub kid: Option<String>,
    /// Optional algorithm name.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub alg: Option<String>,
}

impl std::fmt::Debug for JwkKey {
    /// Formats every member like the derived implementation would, except
    /// that a present `d` is shown as `Some("<redacted>")`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Keep the Some/None distinction visible without exposing the secret.
        let redacted_d = self.d.as_ref().map(|_| "<redacted>");
        f.debug_struct("JwkKey")
            .field("kty", &self.kty)
            .field("crv", &self.crv)
            .field("x", &self.x)
            .field("d", &redacted_d)
            .field("key_use", &self.key_use)
            .field("kid", &self.kid)
            .field("alg", &self.alg)
            .finish()
    }
}
impl JwkKey {
pub fn from_ed25519_keypair(keypair: &KeyPair) -> Self {
let public_key = keypair.public_key();
let secret_key = keypair.secret_key();
Self {
kty: "OKP".to_string(),
crv: Some("Ed25519".to_string()),
x: Some(base64_url_encode(&public_key)),
d: Some(base64_url_encode(&secret_key)),
key_use: Some("sig".to_string()),
kid: None,
alg: Some("EdDSA".to_string()),
}
}
pub fn from_ed25519_public_key(public_key: &PublicKey) -> Self {
Self {
kty: "OKP".to_string(),
crv: Some("Ed25519".to_string()),
x: Some(base64_url_encode(public_key)),
d: None,
key_use: Some("sig".to_string()),
kid: None,
alg: Some("EdDSA".to_string()),
}
}
pub fn to_ed25519_keypair(&self) -> KeyFormatResult<KeyPair> {
if self.kty != "OKP" {
return Err(KeyFormatError::UnsupportedKeyType(self.kty.clone()));
}
if let Some(crv) = &self.crv {
if crv != "Ed25519" {
return Err(KeyFormatError::UnsupportedKeyType(crv.clone()));
}
}
let x = self
.x
.as_ref()
.ok_or_else(|| KeyFormatError::MissingField("x".to_string()))?;
let public_bytes =
base64_url_decode(x).map_err(|e| KeyFormatError::InvalidJwk(e.to_string()))?;
if public_bytes.len() != 32 {
return Err(KeyFormatError::InvalidKeyLength {
expected: 32,
actual: public_bytes.len(),
});
}
let d = self
.d
.as_ref()
.ok_or_else(|| KeyFormatError::MissingField("d".to_string()))?;
let secret_bytes =
base64_url_decode(d).map_err(|e| KeyFormatError::InvalidJwk(e.to_string()))?;
if secret_bytes.len() != 32 {
return Err(KeyFormatError::InvalidKeyLength {
expected: 32,
actual: secret_bytes.len(),
});
}
let mut secret_key = [0u8; 32];
secret_key.copy_from_slice(&secret_bytes);
KeyPair::from_secret_key(&secret_key)
.map_err(|_| KeyFormatError::InvalidJwk("Invalid secret key".to_string()))
}
pub fn to_ed25519_public_key(&self) -> KeyFormatResult<PublicKey> {
if self.kty != "OKP" {
return Err(KeyFormatError::UnsupportedKeyType(self.kty.clone()));
}
let x = self
.x
.as_ref()
.ok_or_else(|| KeyFormatError::MissingField("x".to_string()))?;
let public_bytes =
base64_url_decode(x).map_err(|e| KeyFormatError::InvalidJwk(e.to_string()))?;
if public_bytes.len() != 32 {
return Err(KeyFormatError::InvalidKeyLength {
expected: 32,
actual: public_bytes.len(),
});
}
let mut public_key = [0u8; 32];
public_key.copy_from_slice(&public_bytes);
Ok(public_key)
}
pub fn to_json(&self) -> KeyFormatResult<String> {
serde_json::to_string_pretty(self)
.map_err(|e| KeyFormatError::SerializationError(e.to_string()))
}
pub fn from_json(json: &str) -> KeyFormatResult<Self> {
serde_json::from_str(json).map_err(|e| KeyFormatError::SerializationError(e.to_string()))
}
pub fn with_kid(mut self, kid: impl Into<String>) -> Self {
self.kid = Some(kid.into());
self
}
}
/// Namespace for DER encoding and decoding of raw 32-byte Ed25519 keys.
///
/// Public keys use the `SubjectPublicKeyInfo` layout and private keys the
/// PKCS#8 version 0 layout, both with the Ed25519 algorithm identifier. For a
/// fixed-size key each layout is a constant header followed by the 32 key
/// bytes, so decoding is an exact comparison against that header.
pub struct DerKey;

impl DerKey {
    /// Length in bytes of a raw Ed25519 public or secret key.
    const KEY_LEN: usize = 32;

    /// Header of an Ed25519 `SubjectPublicKeyInfo`.
    const PUBLIC_KEY_PREFIX: [u8; 12] = [
        0x30, 0x2A, // SEQUENCE, 42 bytes
        0x30, 0x05, // SEQUENCE, 5 bytes (AlgorithmIdentifier)
        0x06, 0x03, 0x2B, 0x65, 0x70, // OID 1.3.101.112 (Ed25519)
        0x03, 0x21, 0x00, // BIT STRING, 33 bytes, 0 unused bits
    ];

    /// Header of an Ed25519 PKCS#8 (version 0) private key.
    const PRIVATE_KEY_PREFIX: [u8; 16] = [
        0x30, 0x2E, // SEQUENCE, 46 bytes
        0x02, 0x01, 0x00, // INTEGER 0 (version)
        0x30, 0x05, // SEQUENCE, 5 bytes (AlgorithmIdentifier)
        0x06, 0x03, 0x2B, 0x65, 0x70, // OID 1.3.101.112 (Ed25519)
        0x04, 0x22, // OCTET STRING, 34 bytes
        0x04, 0x20, // OCTET STRING, 32 bytes (the key itself)
    ];

    /// Encodes an Ed25519 public key as a 44-byte DER `SubjectPublicKeyInfo`.
    pub fn encode_ed25519_public_key(public_key: &PublicKey) -> Vec<u8> {
        let mut der = Vec::with_capacity(Self::PUBLIC_KEY_PREFIX.len() + Self::KEY_LEN);
        der.extend_from_slice(&Self::PUBLIC_KEY_PREFIX);
        der.extend_from_slice(public_key);
        der
    }

    /// Decodes an Ed25519 public key from its DER `SubjectPublicKeyInfo`.
    ///
    /// # Errors
    ///
    /// Returns [`KeyFormatError::InvalidDer`] if `der` is shorter than 44
    /// bytes, does not start with a SEQUENCE tag, has trailing data, or does
    /// not carry the Ed25519 header.
    pub fn decode_ed25519_public_key(der: &[u8]) -> KeyFormatResult<PublicKey> {
        if der.len() < Self::PUBLIC_KEY_PREFIX.len() + Self::KEY_LEN {
            return Err(KeyFormatError::InvalidDer("DER data too short".to_string()));
        }
        if der[0] != 0x30 {
            return Err(KeyFormatError::InvalidDer(
                "Expected SEQUENCE tag".to_string(),
            ));
        }
        Self::extract_key(
            der,
            &Self::PUBLIC_KEY_PREFIX,
            "Not an Ed25519 SubjectPublicKeyInfo structure",
        )
    }

    /// Encodes an Ed25519 secret key as a 48-byte DER PKCS#8 structure.
    pub fn encode_ed25519_private_key(secret_key: &SecretKey) -> Vec<u8> {
        let mut der = Vec::with_capacity(Self::PRIVATE_KEY_PREFIX.len() + Self::KEY_LEN);
        der.extend_from_slice(&Self::PRIVATE_KEY_PREFIX);
        der.extend_from_slice(secret_key);
        der
    }

    /// Decodes an Ed25519 secret key from its DER PKCS#8 structure.
    ///
    /// Only the 48-byte version 0 form without attributes or an embedded
    /// public key is accepted. Longer structures are rejected instead of
    /// guessing where the secret key lies.
    ///
    /// # Errors
    ///
    /// Returns [`KeyFormatError::InvalidDer`] if `der` is shorter than 48
    /// bytes, has trailing data, or does not carry the Ed25519 PKCS#8 header.
    pub fn decode_ed25519_private_key(der: &[u8]) -> KeyFormatResult<SecretKey> {
        if der.len() < Self::PRIVATE_KEY_PREFIX.len() + Self::KEY_LEN {
            return Err(KeyFormatError::InvalidDer(
                "DER data too short for private key".to_string(),
            ));
        }
        Self::extract_key(
            der,
            &Self::PRIVATE_KEY_PREFIX,
            "Not an Ed25519 PKCS#8 private key structure",
        )
    }

    /// Returns the 32 key bytes of `der` if it is exactly `prefix` followed
    /// by a key. `mismatch` is the error text used when the header differs.
    fn extract_key(der: &[u8], prefix: &[u8], mismatch: &str) -> KeyFormatResult<[u8; 32]> {
        if der.len() != prefix.len() + Self::KEY_LEN {
            return Err(KeyFormatError::InvalidDer(
                "Unexpected DER length for Ed25519 key".to_string(),
            ));
        }
        if der[..prefix.len()] != *prefix {
            return Err(KeyFormatError::InvalidDer(mismatch.to_string()));
        }
        let mut key = [0u8; 32];
        key.copy_from_slice(&der[prefix.len()..]);
        Ok(key)
    }
}
/// Encodes `data` with the URL-safe base64 alphabet and no padding.
fn base64_url_encode(data: &[u8]) -> String {
    use base64::Engine as _;

    let engine = base64::engine::general_purpose::URL_SAFE_NO_PAD;
    engine.encode(data)
}
/// Decodes unpadded URL-safe base64 text into bytes.
///
/// On failure the decoder's error is returned as its display string.
fn base64_url_decode(data: &str) -> Result<Vec<u8>, String> {
    use base64::Engine as _;

    match base64::engine::general_purpose::URL_SAFE_NO_PAD.decode(data) {
        Ok(bytes) => Ok(bytes),
        Err(err) => Err(err.to_string()),
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A key pair exported as a private JWK must be restored unchanged.
    #[test]
    fn test_jwk_keypair_roundtrip() {
        let keypair = KeyPair::generate();
        let jwk = JwkKey::from_ed25519_keypair(&keypair);
        let restored = jwk.to_ed25519_keypair().unwrap();
        assert_eq!(keypair.public_key(), restored.public_key());
        assert_eq!(keypair.secret_key(), restored.secret_key());
    }

    /// A public key exported as a JWK must be restored unchanged.
    #[test]
    fn test_jwk_public_key_roundtrip() {
        let keypair = KeyPair::generate();
        let public_key = keypair.public_key();
        let jwk = JwkKey::from_ed25519_public_key(&public_key);
        let restored = jwk.to_ed25519_public_key().unwrap();
        assert_eq!(public_key, restored);
    }

    /// Serializing to JSON and parsing it back yields an equal JWK.
    #[test]
    fn test_jwk_json_serialization() {
        let keypair = KeyPair::generate();
        let jwk = JwkKey::from_ed25519_keypair(&keypair);
        let json = jwk.to_json().unwrap();
        let restored = JwkKey::from_json(&json).unwrap();
        assert_eq!(jwk, restored);
    }

    /// `with_kid` stores the given identifier.
    #[test]
    fn test_jwk_with_kid() {
        let keypair = KeyPair::generate();
        let jwk = JwkKey::from_ed25519_keypair(&keypair).with_kid("my-key-id");
        assert_eq!(jwk.kid, Some("my-key-id".to_string()));
    }

    /// A non-OKP key type is rejected as unsupported, naming the bad type.
    #[test]
    fn test_jwk_validation() {
        let invalid_jwk = JwkKey {
            kty: "RSA".to_string(),
            crv: None,
            x: None,
            d: None,
            key_use: None,
            kid: None,
            alg: None,
        };
        // Pin the variant so the test cannot pass because of an unrelated
        // failure such as the missing `x` member.
        assert!(matches!(
            invalid_jwk.to_ed25519_keypair(),
            Err(KeyFormatError::UnsupportedKeyType(kty)) if kty == "RSA"
        ));
    }

    /// The public-key DER encoding round-trips.
    #[test]
    fn test_der_public_key_roundtrip() {
        let keypair = KeyPair::generate();
        let public_key = keypair.public_key();
        let der = DerKey::encode_ed25519_public_key(&public_key);
        let restored = DerKey::decode_ed25519_public_key(&der).unwrap();
        assert_eq!(public_key, restored);
    }

    /// The private-key DER encoding round-trips.
    #[test]
    fn test_der_private_key_roundtrip() {
        let keypair = KeyPair::generate();
        let secret_key = keypair.secret_key();
        let der = DerKey::encode_ed25519_private_key(&secret_key);
        let restored = DerKey::decode_ed25519_private_key(&der).unwrap();
        assert_eq!(secret_key, restored);
    }

    /// The encoded public key is a 44-byte SEQUENCE containing the Ed25519 OID.
    #[test]
    fn test_der_public_key_structure() {
        let keypair = KeyPair::generate();
        let der = DerKey::encode_ed25519_public_key(&keypair.public_key());
        assert_eq!(der.len(), 44);
        assert_eq!(der[0], 0x30);
        assert!(der.windows(3).any(|w| w == [0x2B, 0x65, 0x70]));
    }

    /// Input shorter than a complete structure is rejected as invalid DER.
    #[test]
    fn test_der_rejects_short_input() {
        assert!(matches!(
            DerKey::decode_ed25519_public_key(&[0x30; 10]),
            Err(KeyFormatError::InvalidDer(_))
        ));
        assert!(matches!(
            DerKey::decode_ed25519_private_key(&[0x30; 47]),
            Err(KeyFormatError::InvalidDer(_))
        ));
    }

    /// The base64 helpers round-trip and emit only URL-safe, unpadded text.
    #[test]
    fn test_base64_url_encoding() {
        let data = b"Hello, World!";
        let encoded = base64_url_encode(data);
        let decoded = base64_url_decode(&encoded).unwrap();
        assert_eq!(data, &decoded[..]);
        assert!(!encoded.contains('+'));
        assert!(!encoded.contains('/'));
        assert!(!encoded.contains('='));
    }

    /// A JWK without `x` is reported as missing exactly that field.
    #[test]
    fn test_jwk_missing_fields() {
        let jwk = JwkKey {
            kty: "OKP".to_string(),
            crv: Some("Ed25519".to_string()),
            x: None,
            d: Some("test".to_string()),
            key_use: None,
            kid: None,
            alg: None,
        };
        assert!(matches!(
            jwk.to_ed25519_keypair(),
            Err(KeyFormatError::MissingField(field)) if field == "x"
        ));
    }

    /// An `x` member that decodes to the wrong number of bytes is rejected
    /// with the expected and actual lengths.
    #[test]
    fn test_jwk_invalid_key_length() {
        let jwk = JwkKey {
            kty: "OKP".to_string(),
            crv: Some("Ed25519".to_string()),
            x: Some(base64_url_encode(&[0u8; 16])),
            d: None,
            key_use: None,
            kid: None,
            alg: None,
        };
        assert_eq!(
            jwk.to_ed25519_public_key(),
            Err(KeyFormatError::InvalidKeyLength {
                expected: 32,
                actual: 16,
            })
        );
    }
}