mod compact;
mod flattened;
pub use compact::Compact;
pub use flattened::{Signable, SignedData};
use crate::errors::Error;
use crate::jwa::SignatureAlgorithm;
use crate::jwk;
use crate::{CompactJson, Empty};
use num_bigint::BigUint;
use ring::signature;
use serde::{self, de::DeserializeOwned, Deserialize, Serialize};
use std::sync::Arc;
/// Key material for JWS signing or verification.
///
/// Cloning is cheap for the `ring` key-pair variants because they are shared behind an [`Arc`].
///
/// NOTE(review): `Debug` is not derived. If that is deliberate (to keep secret bytes out of
/// logs), keep it that way; confirm before adding it.
#[derive(Clone)]
pub enum Secret {
    /// No key material. Presumably paired with an unsigned (`none`) algorithm — confirm.
    None,
    /// Raw secret bytes, e.g. as produced by [`Secret::bytes_from_str`].
    Bytes(Vec<u8>),
    /// An RSA key pair, e.g. as loaded by [`Secret::rsa_keypair_from_file`].
    RsaKeyPair(Arc<signature::RsaKeyPair>),
    /// An ECDSA key pair, e.g. as loaded by [`Secret::ecdsa_keypair_from_file`].
    EcdsaKeyPair(Arc<signature::EcdsaKeyPair>),
    /// Public-key bytes read verbatim from a file by [`Secret::public_key_from_file`]; their
    /// encoding is not checked when they are loaded.
    PublicKey(Vec<u8>),
    /// An RSA public key given as its modulus and exponent.
    RSAModulusExponent {
        /// RSA modulus.
        n: BigUint,
        /// RSA public exponent.
        e: BigUint,
    },
}
impl Secret {
    /// Reads the entire file at `path` into memory.
    ///
    /// `std::fs::read` already pre-sizes its buffer from the file metadata. Using it avoids the
    /// manual `with_capacity(len as usize)`, which silently truncated the `u64` length on
    /// narrower targets.
    ///
    /// # Errors
    /// Any I/O error from opening or reading the file, converted into [`Error`].
    fn read_bytes(path: &str) -> Result<Vec<u8>, Error> {
        Ok(std::fs::read(path)?)
    }

    /// Creates a [`Secret::Bytes`] holding the UTF-8 bytes of `secret`.
    pub fn bytes_from_str(secret: &str) -> Self {
        // Copy the bytes directly; no intermediate `String` is needed.
        Secret::Bytes(secret.as_bytes().to_vec())
    }

    /// Loads an RSA key pair from the file at `path`.
    ///
    /// The file must be DER in the form accepted by `ring::signature::RsaKeyPair::from_der`.
    ///
    /// # Errors
    /// Returns an error if the file cannot be read or if `ring` rejects the key.
    pub fn rsa_keypair_from_file(path: &str) -> Result<Self, Error> {
        let der = Self::read_bytes(path)?;
        let key_pair = signature::RsaKeyPair::from_der(der.as_slice())?;
        Ok(Secret::RsaKeyPair(Arc::new(key_pair)))
    }

    /// Loads an ECDSA key pair for `algorithm` from a PKCS#8 file at `path`.
    ///
    /// Only `ES256` (P-256 / SHA-256) and `ES384` (P-384 / SHA-384) are supported. Both use
    /// `ring`'s fixed-length signature encoding.
    ///
    /// # Errors
    /// - [`Error::UnsupportedOperation`] for any other algorithm. This is checked before the
    ///   file is touched.
    /// - An I/O error if the file cannot be read.
    /// - An error if `ring` rejects the key.
    pub fn ecdsa_keypair_from_file(
        algorithm: SignatureAlgorithm,
        path: &str,
    ) -> Result<Self, Error> {
        // Validate the algorithm first so an unsupported request fails without any file I/O.
        let ring_algorithm = match algorithm {
            SignatureAlgorithm::ES256 => &signature::ECDSA_P256_SHA256_FIXED_SIGNING,
            SignatureAlgorithm::ES384 => &signature::ECDSA_P384_SHA384_FIXED_SIGNING,
            _ => return Err(Error::UnsupportedOperation),
        };
        let pkcs8 = Self::read_bytes(path)?;
        let key_pair = signature::EcdsaKeyPair::from_pkcs8(
            ring_algorithm,
            pkcs8.as_slice(),
            &ring::rand::SystemRandom::new(),
        )?;
        Ok(Secret::EcdsaKeyPair(Arc::new(key_pair)))
    }

    /// Loads public-key bytes verbatim from the file at `path` into a [`Secret::PublicKey`].
    ///
    /// The bytes are not parsed or validated here.
    ///
    /// # Errors
    /// Returns an error if the file cannot be read.
    pub fn public_key_from_file(path: &str) -> Result<Self, Error> {
        // `read_bytes` already returns an owned buffer; move it in instead of copying it.
        let der = Self::read_bytes(path)?;
        Ok(Secret::PublicKey(der))
    }
}
impl From<jwk::RSAKeyParameters> for Secret {
    /// Turns RSA JWK parameters into a [`Secret`].
    ///
    /// The work is delegated entirely to `RSAKeyParameters::jws_public_key_secret`.
    fn from(params: jwk::RSAKeyParameters) -> Secret {
        params.jws_public_key_secret()
    }
}
/// A JOSE header: the registered parameters plus caller-defined private parameters.
///
/// Both parts are `#[serde(flatten)]`ed, so they serialize into (and deserialize from) one flat
/// JSON object rather than nested `registered` / `private` objects.
#[derive(Debug, Eq, PartialEq, Clone, Default, Serialize, Deserialize)]
pub struct Header<T> {
    /// Registered header parameters (`alg`, `typ`, `kid`, ...).
    #[serde(flatten)]
    pub registered: RegisteredHeader,
    /// Private (application-specific) header parameters; [`Empty`] when there are none.
    #[serde(flatten)]
    pub private: T,
}
/// Opts `Header<T>` into the crate's [`CompactJson`] trait.
///
/// This applies whenever the private part can be serialized and deserialized with serde.
impl<T> CompactJson for Header<T>
where
    T: Serialize + DeserializeOwned,
{
}
impl Header<Empty> {
pub fn from_registered_header(registered: RegisteredHeader) -> Self {
Self {
registered,
..Default::default()
}
}
}
impl From<RegisteredHeader> for Header<Empty> {
fn from(registered: RegisteredHeader) -> Self {
Self::from_registered_header(registered)
}
}
/// Registered JOSE header parameters (RFC 7515).
///
/// Each field is (de)serialized under its short JSON name, given by the `rename` attributes.
/// Every `Option` field is omitted from the serialized JSON when it is `None`. No field
/// contents are validated by this type.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct RegisteredHeader {
    /// `alg`: the signature algorithm. Always serialized.
    #[serde(rename = "alg")]
    pub algorithm: SignatureAlgorithm,
    /// `typ`: media type of the complete token (the `Default` impl sets `"JWT"`).
    #[serde(rename = "typ", skip_serializing_if = "Option::is_none")]
    pub media_type: Option<String>,
    /// `cty`: content type of the payload.
    #[serde(rename = "cty", skip_serializing_if = "Option::is_none")]
    pub content_type: Option<String>,
    /// `jku`: URL referring to a JWK Set that contains the key.
    #[serde(rename = "jku", skip_serializing_if = "Option::is_none")]
    pub web_key_url: Option<String>,
    /// `jwk`: a key embedded directly in the header as a JWK.
    #[serde(rename = "jwk", skip_serializing_if = "Option::is_none")]
    pub web_key: Option<jwk::JWK<Empty>>,
    /// `kid`: key identifier hinting which key was used.
    #[serde(rename = "kid", skip_serializing_if = "Option::is_none")]
    pub key_id: Option<String>,
    /// `x5u`: URL referring to an X.509 certificate or certificate chain.
    #[serde(rename = "x5u", skip_serializing_if = "Option::is_none")]
    pub x509_url: Option<String>,
    /// `x5c`: X.509 certificate chain, one string per certificate.
    ///
    /// RFC 7515 uses base64-encoded DER for each entry; that is not checked here.
    #[serde(rename = "x5c", skip_serializing_if = "Option::is_none")]
    pub x509_chain: Option<Vec<String>>,
    /// `x5t`: X.509 certificate thumbprint.
    ///
    /// RFC 7515 defines this as a base64url SHA-1 digest; that is not checked here.
    #[serde(rename = "x5t", skip_serializing_if = "Option::is_none")]
    pub x509_fingerprint: Option<String>,
    /// `crit`: names of header extensions that the recipient must understand.
    #[serde(rename = "crit", skip_serializing_if = "Option::is_none")]
    pub critical: Option<Vec<String>>,
}
impl Default for RegisteredHeader {
    /// Returns a header with `typ` set to `"JWT"` and every other optional parameter unset.
    ///
    /// The algorithm is `SignatureAlgorithm::default()`, which serializes as `HS256`.
    fn default() -> Self {
        let algorithm = SignatureAlgorithm::default();
        let media_type = Some(String::from("JWT"));
        Self {
            algorithm,
            media_type,
            content_type: None,
            web_key_url: None,
            web_key: None,
            key_id: None,
            x509_url: None,
            x509_chain: None,
            x509_fingerprint: None,
            critical: None,
        }
    }
}
#[cfg(test)]
mod tests {
    use super::RegisteredHeader;

    /// Round-trip check for a header.
    ///
    /// Serializes `header` and checks the exact JSON text against `json`. It then parses that
    /// text back and checks the result equals `header`.
    fn assert_round_trip(header: &RegisteredHeader, json: &str) {
        let serialized = not_err!(serde_json::to_string(header));
        assert_eq!(json, serialized);
        let parsed: RegisteredHeader = not_err!(serde_json::from_str(&serialized));
        assert_eq!(&parsed, header);
    }

    #[test]
    fn header_serialization_round_trip_no_optional() {
        assert_round_trip(&RegisteredHeader::default(), r#"{"alg":"HS256","typ":"JWT"}"#);
    }

    #[test]
    fn header_serialization_round_trip_with_optional() {
        let header = RegisteredHeader {
            key_id: Some(String::from("kid")),
            ..RegisteredHeader::default()
        };
        assert_round_trip(&header, r#"{"alg":"HS256","typ":"JWT","kid":"kid"}"#);
    }
}