use base64::{Engine, engine::general_purpose::URL_SAFE_NO_PAD};
use rand_core::OsRng;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use zeroize::{Zeroize, ZeroizeOnDrop};
use crate::error::CryptoError;
/// Curves supported for key agreement (X25519 and elliptic-curve Diffie-Hellman).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[non_exhaustive]
pub enum Curve {
    /// Curve25519 used through the X25519 function; JWK `crv` is `"X25519"`.
    X25519,
    /// NIST P-256; JWK `crv` is `"P-256"`.
    P256,
    /// secp256k1; JWK `crv` is `"secp256k1"`.
    K256,
    /// NIST P-384; JWK `crv` is `"P-384"`.
    P384,
    /// NIST P-521; JWK `crv` is `"P-521"`.
    P521,
}

impl Curve {
    /// Returns the string used for this curve in a JWK's `"crv"` member.
    pub fn jwk_crv(&self) -> &'static str {
        match *self {
            Self::X25519 => "X25519",
            Self::P256 => "P-256",
            Self::K256 => "secp256k1",
            Self::P384 => "P-384",
            Self::P521 => "P-521",
        }
    }
}
/// A peer's public key for key agreement, tagged with its curve.
///
/// X25519 keys are stored as their raw 32-byte encoding; the EC variants hold
/// the corresponding curve crate's `PublicKey` type.
#[derive(Debug, Clone)]
#[non_exhaustive]
pub enum PublicKeyAgreement {
    X25519([u8; 32]),
    P256(p256::PublicKey),
    K256(k256::PublicKey),
    P384(p384::PublicKey),
    P521(p521::PublicKey),
}

impl PublicKeyAgreement {
    /// Returns the curve this key belongs to.
    pub fn curve(&self) -> Curve {
        match self {
            PublicKeyAgreement::X25519(_) => Curve::X25519,
            PublicKeyAgreement::P256(_) => Curve::P256,
            PublicKeyAgreement::K256(_) => Curve::K256,
            PublicKeyAgreement::P384(_) => Curve::P384,
            PublicKeyAgreement::P521(_) => Curve::P521,
        }
    }

    /// Serializes the key as a public JWK.
    ///
    /// X25519 keys become `{"kty": "OKP", "crv": "X25519", "x": ...}`; EC keys
    /// become `{"kty": "EC", "crv": ..., "x": ..., "y": ...}` with both affine
    /// coordinates base64url-encoded without padding. The `crv` value is taken
    /// from [`Curve::jwk_crv`] so it cannot drift from the curve table.
    pub fn to_jwk(&self) -> Value {
        // `PublicKey` never wraps the identity point, so its uncompressed SEC1
        // encoding always carries both coordinates.
        const HAS_COORDS: &str = "uncompressed encoding of a non-identity point has x and y";
        let crv = self.curve().jwk_crv();
        match self {
            PublicKeyAgreement::X25519(bytes) => serde_json::json!({
                "kty": "OKP",
                "crv": crv,
                "x": URL_SAFE_NO_PAD.encode(bytes),
            }),
            PublicKeyAgreement::P256(pk) => {
                use p256::elliptic_curve::sec1::ToEncodedPoint;
                let point = pk.to_encoded_point(false);
                Self::ec_jwk(crv, point.x().expect(HAS_COORDS), point.y().expect(HAS_COORDS))
            }
            PublicKeyAgreement::K256(pk) => {
                use k256::elliptic_curve::sec1::ToEncodedPoint;
                let point = pk.to_encoded_point(false);
                Self::ec_jwk(crv, point.x().expect(HAS_COORDS), point.y().expect(HAS_COORDS))
            }
            PublicKeyAgreement::P384(pk) => {
                use p384::elliptic_curve::sec1::ToEncodedPoint;
                let point = pk.to_encoded_point(false);
                Self::ec_jwk(crv, point.x().expect(HAS_COORDS), point.y().expect(HAS_COORDS))
            }
            PublicKeyAgreement::P521(pk) => {
                use p521::elliptic_curve::sec1::ToEncodedPoint;
                let point = pk.to_encoded_point(false);
                Self::ec_jwk(crv, point.x().expect(HAS_COORDS), point.y().expect(HAS_COORDS))
            }
        }
    }

    /// Builds an `"EC"` JWK object from the SEC1-encoded affine coordinates.
    fn ec_jwk(crv: &str, x: &[u8], y: &[u8]) -> Value {
        serde_json::json!({
            "kty": "EC",
            "crv": crv,
            "x": URL_SAFE_NO_PAD.encode(x),
            "y": URL_SAFE_NO_PAD.encode(y),
        })
    }

    /// Returns the key's compact byte encoding.
    ///
    /// X25519 keys are returned as their raw 32 bytes; EC keys are returned
    /// as a *compressed* SEC1 point.
    pub fn to_public_bytes(&self) -> Vec<u8> {
        match self {
            PublicKeyAgreement::X25519(bytes) => bytes.to_vec(),
            PublicKeyAgreement::P256(pk) => {
                use p256::elliptic_curve::sec1::ToEncodedPoint;
                pk.to_encoded_point(true).as_bytes().to_vec()
            }
            PublicKeyAgreement::K256(pk) => {
                use k256::elliptic_curve::sec1::ToEncodedPoint;
                pk.to_encoded_point(true).as_bytes().to_vec()
            }
            PublicKeyAgreement::P384(pk) => {
                use p384::elliptic_curve::sec1::ToEncodedPoint;
                pk.to_encoded_point(true).as_bytes().to_vec()
            }
            PublicKeyAgreement::P521(pk) => {
                use p521::elliptic_curve::sec1::ToEncodedPoint;
                pk.to_encoded_point(true).as_bytes().to_vec()
            }
        }
    }

    /// Parses a public key for `curve` from raw bytes.
    ///
    /// X25519 requires exactly 32 bytes; EC curves accept any SEC1 encoding
    /// understood by the curve crate's `PublicKey::from_sec1_bytes`.
    ///
    /// # Errors
    /// `CryptoError::KeyAgreement` if the bytes are not a valid key for `curve`.
    pub fn from_raw_bytes(curve: Curve, bytes: &[u8]) -> Result<Self, CryptoError> {
        match curve {
            Curve::X25519 => {
                let arr: [u8; 32] = bytes.try_into().map_err(|_| {
                    CryptoError::KeyAgreement("X25519 public key must be 32 bytes".into())
                })?;
                Ok(PublicKeyAgreement::X25519(arr))
            }
            Curve::P256 => {
                let pk = p256::PublicKey::from_sec1_bytes(bytes).map_err(|e| {
                    CryptoError::KeyAgreement(format!("invalid P-256 public key: {e}"))
                })?;
                Ok(PublicKeyAgreement::P256(pk))
            }
            Curve::K256 => {
                let pk = k256::PublicKey::from_sec1_bytes(bytes).map_err(|e| {
                    CryptoError::KeyAgreement(format!("invalid K-256 public key: {e}"))
                })?;
                Ok(PublicKeyAgreement::K256(pk))
            }
            Curve::P384 => {
                let pk = p384::PublicKey::from_sec1_bytes(bytes).map_err(|e| {
                    CryptoError::KeyAgreement(format!("invalid P-384 public key: {e}"))
                })?;
                Ok(PublicKeyAgreement::P384(pk))
            }
            Curve::P521 => {
                let pk = p521::PublicKey::from_sec1_bytes(bytes).map_err(|e| {
                    CryptoError::KeyAgreement(format!("invalid P-521 public key: {e}"))
                })?;
                Ok(PublicKeyAgreement::P521(pk))
            }
        }
    }

    /// Parses a public JWK into a key-agreement public key.
    ///
    /// The curve is selected by the `"crv"` member. If a `"kty"` member is
    /// present it must match the curve family (`"OKP"` for X25519, `"EC"` for
    /// the NIST/secp256k1 curves).
    ///
    /// # Errors
    /// - `CryptoError::KeyAgreement` for missing/undecodable members, a
    ///   contradictory `kty`, or an invalid point.
    /// - `CryptoError::UnsupportedKeyType` for an unrecognized `crv`.
    pub fn from_jwk(jwk: &Value) -> Result<Self, CryptoError> {
        let crv = jwk["crv"]
            .as_str()
            .ok_or_else(|| CryptoError::KeyAgreement("missing crv in JWK".into()))?;
        match crv {
            "X25519" => {
                Self::check_jwk_kty(jwk, "OKP")?;
                let x = jwk["x"]
                    .as_str()
                    .ok_or_else(|| CryptoError::KeyAgreement("missing x in X25519 JWK".into()))?;
                let bytes = URL_SAFE_NO_PAD
                    .decode(x)
                    .map_err(|e| CryptoError::KeyAgreement(format!("invalid x: {e}")))?;
                let arr: [u8; 32] = bytes
                    .try_into()
                    .map_err(|_| CryptoError::KeyAgreement("X25519 key must be 32 bytes".into()))?;
                Ok(PublicKeyAgreement::X25519(arr))
            }
            "P-256" => {
                Self::check_jwk_kty(jwk, "EC")?;
                let point = ec_point_from_jwk(jwk)?;
                let pk = p256::PublicKey::from_sec1_bytes(&point)
                    .map_err(|e| CryptoError::KeyAgreement(format!("invalid P-256 key: {e}")))?;
                Ok(PublicKeyAgreement::P256(pk))
            }
            "secp256k1" => {
                Self::check_jwk_kty(jwk, "EC")?;
                let point = ec_point_from_jwk(jwk)?;
                let pk = k256::PublicKey::from_sec1_bytes(&point)
                    .map_err(|e| CryptoError::KeyAgreement(format!("invalid K-256 key: {e}")))?;
                Ok(PublicKeyAgreement::K256(pk))
            }
            "P-384" => {
                Self::check_jwk_kty(jwk, "EC")?;
                let point = ec_point_from_jwk(jwk)?;
                let pk = p384::PublicKey::from_sec1_bytes(&point)
                    .map_err(|e| CryptoError::KeyAgreement(format!("invalid P-384 key: {e}")))?;
                Ok(PublicKeyAgreement::P384(pk))
            }
            "P-521" => {
                Self::check_jwk_kty(jwk, "EC")?;
                let point = ec_point_from_jwk(jwk)?;
                let pk = p521::PublicKey::from_sec1_bytes(&point)
                    .map_err(|e| CryptoError::KeyAgreement(format!("invalid P-521 key: {e}")))?;
                Ok(PublicKeyAgreement::P521(pk))
            }
            other => Err(CryptoError::UnsupportedKeyType(format!(
                "unsupported key-agreement curve: {other}"
            ))),
        }
    }

    /// Rejects a JWK whose `"kty"` contradicts the family implied by its `"crv"`.
    ///
    /// `kty` is required by RFC 7517, but this parser has historically accepted
    /// JWKs without it, so an absent (or `null`) `kty` is still allowed; only a
    /// present value different from `expected` is an error.
    fn check_jwk_kty(jwk: &Value, expected: &str) -> Result<(), CryptoError> {
        match &jwk["kty"] {
            Value::Null => Ok(()),
            Value::String(kty) if kty == expected => Ok(()),
            other => Err(CryptoError::KeyAgreement(format!(
                "JWK kty {other} does not match crv (expected \"{expected}\")"
            ))),
        }
    }
}
/// Decodes the `"x"`/`"y"` members of an `"EC"` JWK into an uncompressed SEC1
/// point (`0x04 || x || y`).
///
/// Both coordinates must decode to the same non-zero length (RFC 7518 requires
/// each to be the full-length field element). Without this check a short `x`
/// paired with a long `y` would concatenate into a correctly sized buffer that
/// the SEC1 parser splits at a different offset than the JWK describes. The
/// exact per-curve size is left to the caller's `from_sec1_bytes`, which
/// validates the total length.
///
/// # Errors
/// `CryptoError::KeyAgreement` if a coordinate is missing, is not valid
/// unpadded base64url, or the coordinate lengths are empty/unequal.
fn ec_point_from_jwk(jwk: &Value) -> Result<Vec<u8>, CryptoError> {
    let x = jwk["x"]
        .as_str()
        .ok_or_else(|| CryptoError::KeyAgreement("missing x in EC JWK".into()))?;
    let y = jwk["y"]
        .as_str()
        .ok_or_else(|| CryptoError::KeyAgreement("missing y in EC JWK".into()))?;
    let x_bytes = URL_SAFE_NO_PAD
        .decode(x)
        .map_err(|e| CryptoError::KeyAgreement(format!("invalid x: {e}")))?;
    let y_bytes = URL_SAFE_NO_PAD
        .decode(y)
        .map_err(|e| CryptoError::KeyAgreement(format!("invalid y: {e}")))?;
    if x_bytes.is_empty() || x_bytes.len() != y_bytes.len() {
        return Err(CryptoError::KeyAgreement(format!(
            "EC JWK x and y must be non-empty and equal length (x: {} bytes, y: {} bytes)",
            x_bytes.len(),
            y_bytes.len()
        )));
    }
    let mut point = Vec::with_capacity(1 + x_bytes.len() + y_bytes.len());
    point.push(0x04); // SEC1 tag for an uncompressed point
    point.extend_from_slice(&x_bytes);
    point.extend_from_slice(&y_bytes);
    Ok(point)
}
/// A private key for key agreement, tagged with its curve.
///
/// Every field is marked `#[zeroize(skip)]`, so the derived `Zeroize` /
/// `ZeroizeOnDrop` impls do not touch the wrapped keys. Wiping on drop is
/// presumably left to the wrapped types' own zeroize support — TODO confirm
/// the enabled crate features.
#[derive(Clone, Zeroize, ZeroizeOnDrop)]
#[non_exhaustive]
pub enum PrivateKeyAgreement {
    X25519(#[zeroize(skip)] x25519_dalek::StaticSecret),
    P256(#[zeroize(skip)] p256::SecretKey),
    K256(#[zeroize(skip)] k256::SecretKey),
    P384(#[zeroize(skip)] p384::SecretKey),
    P521(#[zeroize(skip)] p521::SecretKey),
}

// Hand-written `Debug` that prints only the variant name, never key material.
impl std::fmt::Debug for PrivateKeyAgreement {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let variant = match self {
            Self::X25519(_) => "X25519",
            Self::P256(_) => "P256",
            Self::K256(_) => "K256",
            Self::P384(_) => "P384",
            Self::P521(_) => "P521",
        };
        write!(f, "PrivateKeyAgreement::{variant}([REDACTED])")
    }
}
impl PrivateKeyAgreement {
    /// Parses a private key for `curve` from its raw scalar bytes.
    ///
    /// X25519 requires exactly 32 bytes; EC curves are parsed with the curve
    /// crate's `SecretKey::from_slice`, which validates the scalar.
    ///
    /// # Errors
    /// `CryptoError::KeyAgreement` if the bytes are not a valid key for `curve`.
    pub fn from_raw_bytes(curve: Curve, bytes: &[u8]) -> Result<Self, CryptoError> {
        match curve {
            Curve::X25519 => {
                let mut arr: [u8; 32] = bytes.try_into().map_err(|_| {
                    CryptoError::KeyAgreement("X25519 private key must be 32 bytes".into())
                })?;
                let sk = x25519_dalek::StaticSecret::from(arr);
                // `[u8; 32]` is `Copy`, so `from` left this stack copy of the
                // secret behind; wipe it.
                arr.zeroize();
                Ok(PrivateKeyAgreement::X25519(sk))
            }
            Curve::P256 => {
                let sk = p256::SecretKey::from_slice(bytes).map_err(|e| {
                    CryptoError::KeyAgreement(format!("invalid P-256 private key: {e}"))
                })?;
                Ok(PrivateKeyAgreement::P256(sk))
            }
            Curve::K256 => {
                let sk = k256::SecretKey::from_slice(bytes).map_err(|e| {
                    CryptoError::KeyAgreement(format!("invalid K-256 private key: {e}"))
                })?;
                Ok(PrivateKeyAgreement::K256(sk))
            }
            Curve::P384 => {
                let sk = p384::SecretKey::from_slice(bytes).map_err(|e| {
                    CryptoError::KeyAgreement(format!("invalid P-384 private key: {e}"))
                })?;
                Ok(PrivateKeyAgreement::P384(sk))
            }
            Curve::P521 => {
                let sk = p521::SecretKey::from_slice(bytes).map_err(|e| {
                    CryptoError::KeyAgreement(format!("invalid P-521 private key: {e}"))
                })?;
                Ok(PrivateKeyAgreement::P521(sk))
            }
        }
    }

    /// Generates a fresh random private key on `curve` using `OsRng`.
    pub fn generate(curve: Curve) -> Self {
        match curve {
            Curve::X25519 => {
                PrivateKeyAgreement::X25519(x25519_dalek::StaticSecret::random_from_rng(OsRng))
            }
            Curve::P256 => PrivateKeyAgreement::P256(p256::SecretKey::random(&mut OsRng)),
            Curve::K256 => PrivateKeyAgreement::K256(k256::SecretKey::random(&mut OsRng)),
            Curve::P384 => PrivateKeyAgreement::P384(p384::SecretKey::random(&mut OsRng)),
            Curve::P521 => PrivateKeyAgreement::P521(p521::SecretKey::random(&mut OsRng)),
        }
    }

    /// Derives the public key matching this private key.
    pub fn public_key(&self) -> PublicKeyAgreement {
        match self {
            PrivateKeyAgreement::X25519(sk) => {
                PublicKeyAgreement::X25519(x25519_dalek::PublicKey::from(sk).to_bytes())
            }
            PrivateKeyAgreement::P256(sk) => PublicKeyAgreement::P256(sk.public_key()),
            PrivateKeyAgreement::K256(sk) => PublicKeyAgreement::K256(sk.public_key()),
            PrivateKeyAgreement::P384(sk) => PublicKeyAgreement::P384(sk.public_key()),
            PrivateKeyAgreement::P521(sk) => PublicKeyAgreement::P521(sk.public_key()),
        }
    }

    /// Returns the curve this key belongs to.
    pub fn curve(&self) -> Curve {
        match self {
            PrivateKeyAgreement::X25519(_) => Curve::X25519,
            PrivateKeyAgreement::P256(_) => Curve::P256,
            PrivateKeyAgreement::K256(_) => Curve::K256,
            PrivateKeyAgreement::P384(_) => Curve::P384,
            PrivateKeyAgreement::P521(_) => Curve::P521,
        }
    }

    /// Performs Diffie-Hellman with `their_public` and returns the raw shared
    /// secret: the X25519 output, or the x-coordinate of the shared EC point.
    ///
    /// The returned `Vec` is not wiped on drop; callers should zeroize it once
    /// the derived key material has been computed.
    ///
    /// # Errors
    /// `CryptoError::KeyAgreement` if the two keys are on different curves, or
    /// if an X25519 exchange produces the all-zero secret (low-order peer key).
    pub fn diffie_hellman(
        &self,
        their_public: &PublicKeyAgreement,
    ) -> Result<Vec<u8>, CryptoError> {
        match (self, their_public) {
            (PrivateKeyAgreement::X25519(sk), PublicKeyAgreement::X25519(pk)) => {
                let shared = sk.diffie_hellman(&x25519_dalek::PublicKey::from(*pk));
                // A low-order peer point forces an all-zero output regardless
                // of our key (RFC 7748 §6.1); refuse to derive keys from it.
                if !shared.was_contributory() {
                    return Err(CryptoError::KeyAgreement(
                        "X25519 shared secret is all-zero (low-order public key)".into(),
                    ));
                }
                Ok(shared.as_bytes().to_vec())
            }
            (PrivateKeyAgreement::P256(sk), PublicKeyAgreement::P256(pk)) => {
                use p256::ecdh::diffie_hellman;
                let shared = diffie_hellman(sk.to_nonzero_scalar(), pk.as_affine());
                Ok(shared.raw_secret_bytes().to_vec())
            }
            (PrivateKeyAgreement::K256(sk), PublicKeyAgreement::K256(pk)) => {
                use k256::ecdh::diffie_hellman;
                let shared = diffie_hellman(sk.to_nonzero_scalar(), pk.as_affine());
                Ok(shared.raw_secret_bytes().to_vec())
            }
            (PrivateKeyAgreement::P384(sk), PublicKeyAgreement::P384(pk)) => {
                use p384::ecdh::diffie_hellman;
                let shared = diffie_hellman(sk.to_nonzero_scalar(), pk.as_affine());
                Ok(shared.raw_secret_bytes().to_vec())
            }
            (PrivateKeyAgreement::P521(sk), PublicKeyAgreement::P521(pk)) => {
                use p521::ecdh::diffie_hellman;
                let shared = diffie_hellman(sk.to_nonzero_scalar(), pk.as_affine());
                Ok(shared.raw_secret_bytes().to_vec())
            }
            _ => Err(CryptoError::KeyAgreement(
                "curve mismatch between private and public keys".into(),
            )),
        }
    }
}
/// A key pair intended for a single key-agreement exchange.
///
/// `Debug` is derived; the private half prints through
/// `PrivateKeyAgreement`'s redacting `Debug` impl, so no key material is shown.
#[derive(Debug)]
pub struct EphemeralKeyPair {
    pub private: PrivateKeyAgreement,
    pub public: PublicKeyAgreement,
}

impl EphemeralKeyPair {
    /// Generates a fresh private key on `curve` and derives its public key.
    pub fn generate(curve: Curve) -> Self {
        let private = PrivateKeyAgreement::generate(curve);
        let public = private.public_key();
        Self { private, public }
    }
}