use std::fmt;
use std::time::Duration;
use aead::generic_array::ArrayLength;
use base64ct::{Base64Url, Base64UrlUnpadded, Encoding};
use ecdsa::signature::Verifier;
use ecdsa::{Signature, SignatureSize, VerifyingKey};
use elliptic_curve::sec1::{EncodedPoint, FromEncodedPoint, ModulusSize, ToEncodedPoint};
use elliptic_curve::zeroize::Zeroizing;
#[cfg(test)]
use elliptic_curve::SecretKey;
use elliptic_curve::{
AffinePoint, Curve, CurveArithmetic, FieldBytes, FieldBytesSize, JwkParameters, PrimeCurve,
PublicKey,
};
use serde::{Deserialize, Serialize};
use serde_json::{json, Value};
use sha2::{Digest, Sha256};
#[cfg(test)]
use zeroize::Zeroize;
use crate::key_exchange::{create_enc_key, recover_enc_key};
use crate::util::{b64_to_bytes, b64_to_str};
use crate::{EncryptionKey, Error, Result, TangClient};
/// A signed key advertisement made of a protected header, a payload and a signature.
///
/// `protected` and `payload` are deserialized through `b64_to_str` and `signature` through
/// `b64_to_bytes` (helpers from `crate::util`; presumably base64 decoders — confirm there).
#[derive(Clone, Deserialize)]
pub struct Advertisment {
/// Protected header text; it is re-encoded as unpadded base64url when the signature is checked.
#[serde(deserialize_with = "b64_to_str")]
protected: String,
/// Payload text; `validate_into_keys` parses it as a JSON `JwkSet`.
#[serde(deserialize_with = "b64_to_str")]
payload: String,
/// Raw signature bytes checked against `<protected>.<payload>` (both base64url encoded).
#[serde(deserialize_with = "b64_to_bytes")]
signature: Vec<u8>,
}
impl Advertisment {
    /// Verifies the advertisement signature against a key from `jwks`.
    ///
    /// When `thumbprint` is given, the key with that thumbprint is used and the same thumbprint
    /// is returned. Otherwise the first key advertising the `verify` operation is used and its
    /// SHA-256 thumbprint is returned.
    ///
    /// # Errors
    ///
    /// Fails when no suitable key is found or when the signature does not verify.
    fn validate(&self, jwks: &JwkSet, thumbprint: Option<&str>) -> Result<Box<str>> {
        let (signing_key, signing_thp) = match thumbprint {
            Some(wanted) => (jwks.get_key_by_id(wanted)?, Box::<str>::from(wanted)),
            None => {
                let key = jwks.get_key_by_op("verify")?;
                (key, key.make_thumbprint(ThpHashAlg::Sha256))
            }
        };

        // The signed message is `base64url(protected) . base64url(payload)`, without padding.
        let signing_input = format!(
            "{}.{}",
            Base64UrlUnpadded::encode_string(self.protected.as_bytes()),
            Base64UrlUnpadded::encode_string(self.payload.as_bytes()),
        );

        signing_key.verify(signing_input.as_bytes(), &self.signature)?;
        Ok(signing_thp)
    }

    /// Parses the payload as a key set, verifies the signature, and returns the key set together
    /// with the thumbprint of the key that verified it.
    ///
    /// # Errors
    ///
    /// Fails when the payload is not a valid `JwkSet` or when validation fails.
    pub fn validate_into_keys(self, thumbprint: Option<&str>) -> Result<(JwkSet, Box<str>)> {
        let key_set = serde_json::from_str::<JwkSet>(&self.payload)?;
        self.validate(&key_set, thumbprint)
            .map(|signing_thp| (key_set, signing_thp))
    }
}
impl fmt::Debug for Advertisment {
    /// Prints `payload` and `protected` as parsed JSON when they parse (raw text otherwise) and
    /// the signature as unpadded base64url.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        /// Returns the parsed JSON value when `raw` is valid JSON, otherwise `raw` itself.
        fn parsed_or_raw(raw: &str) -> Box<dyn fmt::Debug + '_> {
            if let Ok(value) = serde_json::from_str::<Value>(raw) {
                return Box::new(value);
            }
            Box::new(raw)
        }

        let payload = parsed_or_raw(&self.payload);
        let protected = parsed_or_raw(&self.protected);
        let signature = Base64UrlUnpadded::encode_string(&self.signature);

        let mut out = f.debug_struct("Advertisment");
        out.field("payload", &payload);
        out.field("protected", &protected);
        out.field("signature", &signature);
        out.finish()
    }
}
/// A JSON Web Key: key-type specific parameters plus the optional `key_ops` and `alg` members.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct Jwk {
/// Key-type specific parameters; flattened into the same JSON object as the other members.
#[serde(flatten)]
pub inner: JwkInner,
/// Operations the key is advertised for; omitted from the serialized form when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub key_ops: Option<Vec<Box<str>>>,
/// Algorithm name; omitted from the serialized form when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub alg: Option<Box<str>>,
}
/// Key-type specific part of a [`Jwk`], selected by the `kty` member (`"EC"` or `"RSA"`).
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(tag = "kty", rename_all = "UPPERCASE")]
pub enum JwkInner {
/// Elliptic-curve key parameters.
Ec(EcJwk),
/// RSA key parameters; in this file RSA keys can be thumbprinted but not used for verification.
Rsa(RsaJwk),
}
/// Elliptic-curve JWK parameters, kept in their textual (base64url) form.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct EcJwk {
/// Curve name; `get_curve` maps it to a [`JwkCurve`].
pub crv: Box<str>,
/// Encoded x coordinate of the public point.
pub x: Box<str>,
/// Encoded y coordinate of the public point; `to_pub` fails when it is absent.
#[serde(skip_serializing_if = "Option::is_none")]
pub y: Option<Box<str>>,
/// Encoded private scalar, if present; wrapped in `Zeroizing` and masked by `Display for Jwk`.
#[serde(skip_serializing_if = "Option::is_none")]
pub d: Option<Zeroizing<Box<str>>>,
}
/// RSA JWK parameters, kept in their textual form; this file only uses them for thumbprints.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct RsaJwk {
/// The `e` member of the key (public exponent in JWK terms).
pub e: Box<str>,
/// The `n` member of the key (modulus in JWK terms).
pub n: Box<str>,
}
impl Jwk {
    /// Computes the key thumbprint with `alg`, delegating to the key-type specific implementation.
    fn make_thumbprint(&self, alg: ThpHashAlg) -> Box<str> {
        match self.inner {
            JwkInner::Ec(ref ec) => ec.make_thumbprint(alg),
            JwkInner::Rsa(ref rsa) => rsa.make_thumbprint(alg),
        }
    }

    /// Borrows the elliptic-curve parameters of this key.
    ///
    /// # Errors
    ///
    /// Returns `Error::Algorithm("RSA")` when the key is an RSA key.
    pub(crate) fn as_ec(&self) -> Result<&EcJwk> {
        match self.inner {
            JwkInner::Ec(ref ec) => Ok(ec),
            JwkInner::Rsa(_) => Err(Error::Algorithm("RSA".into())),
        }
    }

    /// Verifies `signature` over `message` with this key.
    ///
    /// # Errors
    ///
    /// Returns `Error::Algorithm("RSA")` for RSA keys, otherwise whatever EC verification reports.
    pub(crate) fn verify(&self, message: &[u8], signature: &[u8]) -> Result<()> {
        // Only EC keys can verify; `as_ec` yields the RSA algorithm error for anything else.
        self.as_ec()?.verify(message, signature)
    }
}
impl fmt::Display for Jwk {
    /// Writes the key as JSON, replacing any EC private component `d` with `"****"`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Work on a copy so the real key is left untouched.
        let mut redacted = self.clone();
        if let JwkInner::Ec(ref mut ec) = redacted.inner {
            if let Some(secret) = ec.d.as_mut() {
                *secret = Zeroizing::new("****".into());
            }
        }
        let rendered = serde_json::to_string(&redacted).unwrap();
        f.write_str(&rendered)
    }
}
fn encode_base64url_fe<C: Curve>(field: &FieldBytes<C>) -> Box<str> {
Base64Url::encode_string(field).into()
}
fn decode_base64url_fe<C: Curve>(s: &str) -> Result<FieldBytes<C>> {
let mut result = FieldBytes::<C>::default();
Base64Url::decode(s, &mut result).map_err(|_| Error::EllipitcCurve)?;
Ok(result)
}
impl EcJwk {
pub(crate) fn to_pub<C>(&self) -> Result<PublicKey<C>>
where
C: CurveArithmetic + JwkParameters,
AffinePoint<C>: FromEncodedPoint<C> + ToEncodedPoint<C>,
FieldBytesSize<C>: ModulusSize,
{
assert_eq!(self.crv.as_ref(), C::CRV);
let Some(ref y) = self.y else {
return Err(Error::InvalidPublicKey(Jwk::from(self.clone()).into()));
};
let x = decode_base64url_fe::<C>(&self.x)?;
let y = decode_base64url_fe::<C>(y)?;
let affine = EncodedPoint::<C>::from_affine_coordinates(&x, &y, false);
PublicKey::from_sec1_bytes(affine.as_bytes()).map_err(Into::into)
}
pub(crate) fn from_pub<C>(key: &PublicKey<C>) -> Self
where
C: CurveArithmetic + JwkParameters,
AffinePoint<C>: ToEncodedPoint<C>,
FieldBytesSize<C>: ModulusSize,
{
let point = key.as_affine().to_encoded_point(false);
let x = encode_base64url_fe::<C>(point.x().expect("unexpected identity point"));
let y = encode_base64url_fe::<C>(point.y().expect("unexpected identity point"));
Self {
crv: C::CRV.into(),
x,
y: Some(y),
d: None,
}
}
#[cfg(test)]
pub(crate) fn to_priv<C>(&self) -> Result<SecretKey<C>>
where
C: CurveArithmetic,
{
let Some(d) = &self.d else {
panic!("expected private key but got public")
};
let mut d_bytes = decode_base64url_fe::<C>(d.as_ref())?;
let result = SecretKey::<C>::from_slice(&d_bytes)?;
d_bytes.zeroize();
Ok(result)
}
pub(crate) fn get_curve(&self) -> Result<JwkCurve> {
match self.crv.as_ref() {
"P-256" => Ok(JwkCurve::P256),
"P-284" => Ok(JwkCurve::P384),
"P-521" => Ok(JwkCurve::P521),
other => Err(Error::Algorithm(other.into())),
}
}
pub(crate) fn verify(&self, message: &[u8], signature: &[u8]) -> Result<()> {
match self.get_curve()? {
JwkCurve::P256 => self.verify_inner::<p256::NistP256>(message, signature),
JwkCurve::P384 => self.verify_inner::<p384::NistP384>(message, signature),
JwkCurve::P521 => self.verify_p521(message, signature),
}
}
pub(crate) fn make_thumbprint(&self, alg: ThpHashAlg) -> Box<str> {
let to_enc = json! {{
"crv": &self.crv,
"kty": "EC",
"x": &self.x,
"y": &self.y
}};
alg.hash_data_to_string(to_enc.to_string().as_bytes())
}
fn verify_inner<C>(&self, msg: &[u8], sig: &[u8]) -> Result<()>
where
C: CurveArithmetic + PrimeCurve + JwkParameters,
VerifyingKey<C>: Verifier<Signature<C>>,
AffinePoint<C>: FromEncodedPoint<C> + ToEncodedPoint<C>,
FieldBytesSize<C>: ModulusSize,
SignatureSize<C>: ArrayLength<u8>,
{
let pubkey = self.to_pub::<C>()?;
let verify_key = VerifyingKey::from_affine(*pubkey.as_affine())?;
let signature = Signature::from_slice(sig)?;
verify_key
.verify(msg, &signature)
.map_err(|_| Error::FailedVerification)
}
fn verify_p521(&self, msg: &[u8], sig: &[u8]) -> Result<()> {
use p521::ecdsa::{Signature, VerifyingKey};
let pubkey = self.to_pub::<p521::NistP521>()?;
let verify_key = VerifyingKey::from_affine(*pubkey.as_affine())?;
let signature = Signature::from_slice(sig)?;
verify_key
.verify(msg, &signature)
.map_err(|_| Error::FailedVerification)
}
}
impl From<EcJwk> for Jwk {
    /// Wraps EC parameters in a [`Jwk`] with no `key_ops` and no `alg`.
    fn from(value: EcJwk) -> Self {
        let inner = JwkInner::Ec(value);
        Self {
            inner,
            key_ops: None,
            alg: None,
        }
    }
}
impl TryFrom<Jwk> for EcJwk {
type Error = Error;
fn try_from(value: Jwk) -> std::result::Result<Self, Self::Error> {
match value.inner {
JwkInner::Ec(ec_key) => Ok(ec_key),
JwkInner::Rsa(_) => Err(Error::Algorithm("RSA".into())),
}
}
}
/// Curves that `EcJwk::get_curve` recognizes and `EcJwk::verify` dispatches on.
#[derive(Clone, Copy, Debug)]
pub(crate) enum JwkCurve {
/// Verified with `p256::NistP256`.
P256,
/// Verified with `p384::NistP384`.
P384,
/// Verified with `p521::NistP521`.
P521,
}
impl RsaJwk {
    /// Computes the thumbprint of this key: the `e`, `kty` and `n` members are serialized as a
    /// JSON object and hashed with `alg`.
    fn make_thumbprint(&self, alg: ThpHashAlg) -> Box<str> {
        let members = json!({ "e": &self.e, "kty": "RSA", "n": &self.n });
        let serialized = members.to_string();
        alg.hash_data_to_string(serialized.as_bytes())
    }
}
/// A set of JSON Web Keys, (de)serialized as an object with a single `keys` array.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct JwkSet {
/// The keys in the order they appear in the document; lookups return the first match.
keys: Vec<Jwk>,
}
impl JwkSet {
fn get_key_by_op(&self, op_name: &str) -> Result<&Jwk> {
self.keys
.iter()
.find(|key| {
key.key_ops.as_ref().map_or(false, |ops| {
ops.iter().any(|op| op.eq_ignore_ascii_case(op_name))
})
})
.ok_or(Error::MissingKeyOp(op_name.into()))
}
pub(crate) fn get_key_by_id(&self, kid: &str) -> Result<&Jwk> {
for key in &self.keys {
let thp_sha256 = key.make_thumbprint(ThpHashAlg::Sha256);
if thp_sha256.as_ref() == kid {
return Ok(key);
}
let thp_sha1 = key.make_thumbprint(ThpHashAlg::Sha1);
if thp_sha1.as_ref() == kid {
return Ok(key);
}
}
Err(Error::MissingPublicKey)
}
pub(crate) fn make_tang_enc_key<const N: usize>(
&self,
url: &str,
signing_thumbprint: Box<str>,
) -> Result<ProvisionedData<N>> {
let derive_jwk = self.get_key_by_op("deriveKey")?.clone();
let derive_jwk = derive_jwk.as_ec()?;
let (epk, encryption_key) = create_enc_key(derive_jwk)?;
let clevis = ClevisParams {
pin: "tang".into(),
tang: TangParams {
adv: self.clone(),
url: url.into(),
},
};
let meta = KeyMeta {
alg: "ECDH-ES".into(),
clevis,
enc: None,
epk: epk.into(),
kid: derive_jwk.make_thumbprint(ThpHashAlg::Sha256),
};
Ok(ProvisionedData {
encryption_key,
signing_thumbprint,
meta,
})
}
}
/// Hash algorithm used when computing a key thumbprint.
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum ThpHashAlg {
/// SHA-1; `JwkSet::get_key_by_id` accepts it as a fallback after SHA-256.
Sha1,
/// SHA-256; used for the thumbprints this file generates itself.
Sha256,
}
impl ThpHashAlg {
    /// Hashes `data` with the selected algorithm and returns the digest as unpadded base64url.
    fn hash_data_to_string(self, data: &[u8]) -> Box<str> {
        let encoded = match self {
            Self::Sha1 => {
                let mut state = sha1::Sha1::new();
                state.update(data);
                let digest = state.finalize();
                Base64UrlUnpadded::encode_string(&digest)
            }
            Self::Sha256 => {
                let mut state = Sha256::new();
                state.update(data);
                let digest = state.finalize();
                Base64UrlUnpadded::encode_string(&digest)
            }
        };
        encoded.into_boxed_str()
    }
}
/// Result of `JwkSet::make_tang_enc_key`: a new encryption key plus the data needed to recover it.
pub struct ProvisionedData<const KEYBYTES: usize> {
/// The freshly derived encryption key.
pub encryption_key: EncryptionKey<KEYBYTES>,
/// Thumbprint of the signing key, passed through from the caller of `make_tang_enc_key`.
pub signing_thumbprint: Box<str>,
/// Metadata from which `KeyMeta::recover_key` can re-derive the key.
pub meta: KeyMeta,
}
/// Serializable metadata describing how an encryption key was derived.
#[derive(Debug, Deserialize, Serialize)]
pub struct KeyMeta {
/// Algorithm name; `make_tang_enc_key` sets it to `"ECDH-ES"`.
alg: Box<str>,
/// Pin parameters, including the key set and the server URL.
clevis: ClevisParams,
/// Optional `enc` member; `make_tang_enc_key` leaves it `None`.
enc: Option<Box<str>>,
/// Public key produced by `create_enc_key` when the encryption key was made.
epk: Jwk,
/// SHA-256 thumbprint of the `deriveKey` key used during provisioning.
kid: Box<str>,
}
/// The `clevis` member of [`KeyMeta`].
#[derive(Debug, Deserialize, Serialize)]
struct ClevisParams {
/// Pin name; `make_tang_enc_key` sets it to `"tang"`.
pin: Box<str>,
/// Parameters of the tang pin.
tang: TangParams,
}
/// The `tang` member of [`ClevisParams`].
#[derive(Debug, Deserialize, Serialize)]
struct TangParams {
/// Copy of the key set the encryption key was provisioned against.
adv: JwkSet,
/// Server URL; used by `KeyMeta::client` and returned by `KeyMeta::url`.
url: Box<str>,
}
impl KeyMeta {
    /// Serializes the metadata to a JSON string.
    ///
    /// # Panics
    ///
    /// Panics if serialization fails.
    pub fn to_json(&self) -> String {
        let serialized = serde_json::to_string(self);
        serialized.expect("serialization failure")
    }

    /// Parses metadata from a JSON string.
    pub fn from_json(val: &str) -> Result<Self> {
        Ok(serde_json::from_str(val)?)
    }

    /// Parses metadata from JSON bytes.
    pub fn from_json_bytes(val: &[u8]) -> Result<Self> {
        Ok(serde_json::from_slice(val)?)
    }

    /// Creates a client for the stored server URL with the given timeout.
    pub fn client(&self, timeout: Option<Duration>) -> TangClient {
        let tang = &self.clevis.tang;
        TangClient::new(&tang.url, timeout)
    }

    /// Returns the stored server URL.
    pub fn url(&self) -> &str {
        let tang = &self.clevis.tang;
        &tang.url
    }

    /// Re-derives the encryption key described by this metadata.
    ///
    /// `server_key_exchange` is called with the stored `kid` and a JWK handed over by
    /// `recover_enc_key`, and must return the server's response JWK.
    ///
    /// # Errors
    ///
    /// Fails when `epk` or the stored server key is not an EC key, when no stored key matches
    /// `kid`, or when recovery itself fails.
    pub(crate) fn recover_key<const N: usize>(
        &self,
        server_key_exchange: impl FnOnce(&str, &Jwk) -> Result<Jwk>,
    ) -> Result<EncryptionKey<N>> {
        let client_pub = self.epk.as_ec()?;
        let kid = &self.kid;
        let stored_keys = &self.clevis.tang.adv;
        let server_pub = stored_keys.get_key_by_id(kid)?.as_ec()?;
        recover_enc_key(client_pub, server_pub, |exchange_jwk| {
            server_key_exchange(kid, exchange_jwk)
        })
    }
}
// Unit tests for this module live in `jose_tests.rs` and are only compiled for test builds.
#[cfg(test)]
#[path = "jose_tests.rs"]
mod tests;