use crate::hash::HashAlgorithm;
use crate::jose::jwk::Jwk;
use crate::key::{PrivateKey, PublicKey};
use crate::signature::{SignatureAlgorithm, SignatureError};
use base64::DecodeError;
use core::convert::TryFrom;
use serde::{Deserialize, Serialize};
use thiserror::Error;
#[derive(Debug, Error)]
#[non_exhaustive]
pub enum JwsError {
#[error("RSA error: {context}")]
Rsa { context: String },
#[error("JSON error: {source}")]
Json { source: serde_json::Error },
#[error("signature error: {source}")]
Signature { source: SignatureError },
#[error("input isn't a valid token string: {input}")]
InvalidEncoding { input: String },
#[error("couldn't decode base64: {source}")]
Base64Decoding { source: DecodeError },
#[error("input isn't valid utf8: {source}, input: {input:?}")]
InvalidUtf8 {
source: std::string::FromUtf8Error,
input: Vec<u8>,
},
}
impl From<rsa::errors::Error> for JwsError {
fn from(e: rsa::errors::Error) -> Self {
Self::Rsa { context: e.to_string() }
}
}
impl From<serde_json::Error> for JwsError {
fn from(e: serde_json::Error) -> Self {
Self::Json { source: e }
}
}
impl From<SignatureError> for JwsError {
fn from(e: SignatureError) -> Self {
Self::Signature { source: e }
}
}
impl From<DecodeError> for JwsError {
fn from(e: DecodeError) -> Self {
Self::Base64Decoding { source: e }
}
}
/// Values of the JWS `alg` header parameter (RFC 7518 §3.1).
///
/// Serialized by serde under the variant name itself (e.g. `"RS256"`), since no
/// `rename` attribute is applied.
///
/// Only `RS256`, `RS384` and `RS512` convert to a [`SignatureAlgorithm`] in this
/// module. Converting any other variant fails with
/// `SignatureError::UnsupportedAlgorithm`, so tokens using those algorithms can be
/// decoded without validation but can be neither signed nor verified here.
///
/// Variant order matters for the derived `Hash` and for serde formats that encode
/// the variant index; do not reorder.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum JwsAlg {
    /// HMAC using SHA-256.
    HS256,
    /// HMAC using SHA-384.
    HS384,
    /// HMAC using SHA-512.
    HS512,
    /// RSASSA-PKCS1-v1_5 using SHA-256.
    RS256,
    /// RSASSA-PKCS1-v1_5 using SHA-384.
    RS384,
    /// RSASSA-PKCS1-v1_5 using SHA-512.
    RS512,
    /// ECDSA using P-256 and SHA-256.
    ES256,
    /// ECDSA using P-384 and SHA-384.
    ES384,
    /// ECDSA using P-521 and SHA-512.
    ES512,
    /// RSASSA-PSS using SHA-256 and MGF1 with SHA-256.
    PS256,
    /// RSASSA-PSS using SHA-384 and MGF1 with SHA-384.
    PS384,
    /// RSASSA-PSS using SHA-512 and MGF1 with SHA-512.
    PS512,
}
impl TryFrom<SignatureAlgorithm> for JwsAlg {
type Error = SignatureError;
fn try_from(v: SignatureAlgorithm) -> Result<Self, Self::Error> {
match v {
SignatureAlgorithm::RsaPkcs1v15(HashAlgorithm::SHA2_256) => Ok(Self::RS256),
SignatureAlgorithm::RsaPkcs1v15(HashAlgorithm::SHA2_384) => Ok(Self::RS384),
SignatureAlgorithm::RsaPkcs1v15(HashAlgorithm::SHA2_512) => Ok(Self::RS512),
unsupported => Err(SignatureError::UnsupportedAlgorithm {
algorithm: format!("{:?}", unsupported),
}),
}
}
}
impl TryFrom<JwsAlg> for SignatureAlgorithm {
type Error = SignatureError;
fn try_from(v: JwsAlg) -> Result<Self, Self::Error> {
match v {
JwsAlg::RS256 => Ok(SignatureAlgorithm::RsaPkcs1v15(HashAlgorithm::SHA2_256)),
JwsAlg::RS384 => Ok(SignatureAlgorithm::RsaPkcs1v15(HashAlgorithm::SHA2_384)),
JwsAlg::RS512 => Ok(SignatureAlgorithm::RsaPkcs1v15(HashAlgorithm::SHA2_512)),
unsupported => Err(SignatureError::UnsupportedAlgorithm {
algorithm: format!("{:?}", unsupported),
}),
}
}
}
/// JOSE header of a [`Jws`], holding the registered header parameters of
/// RFC 7515 §4.1.
///
/// Optional parameters are omitted from the serialized JSON when they are `None`.
/// Fields are serialized in declaration order, so reordering them changes the
/// encoded header segment and therefore the signature produced by [`Jws::encode`].
///
/// No `deny_unknown_fields` is set, so header parameters not listed here are
/// silently ignored when deserializing.
/// NOTE(review): that includes `crit`; RFC 7515 §4.1.11 requires rejecting tokens
/// whose `crit` lists extensions the recipient does not understand.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct JwsHeader {
    /// `alg` (§4.1.1): signing algorithm. Required; deserialization fails without it.
    pub alg: JwsAlg,
    /// `jku` (§4.1.2): URL of a JWK Set containing the signing key. Not fetched here.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub jku: Option<String>,
    /// `jwk` (§4.1.3): the public key embedded in the header.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub jwk: Option<Jwk>,
    /// `typ` (§4.1.9): media type of the complete token (e.g. `"JWT"`).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub typ: Option<String>,
    /// `cty` (§4.1.10): media type of the payload.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub cty: Option<String>,
    /// `kid` (§4.1.4): hint identifying the signing key.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub kid: Option<String>,
    /// `x5u` (§4.1.5): URL of an X.509 certificate chain. Not fetched here.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub x5u: Option<String>,
    /// `x5c` (§4.1.6): X.509 certificate chain as a list of encoded certificates.
    /// The contents are not validated by this module.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub x5c: Option<Vec<String>>,
    /// `x5t` (§4.1.7): SHA-1 thumbprint of the signing certificate.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub x5t: Option<String>,
    /// `x5t#S256` (§4.1.8): SHA-256 thumbprint of the signing certificate.
    /// Serialized as `x5t#S256`; the lowercase `x5t#s256` is also accepted on input.
    #[serde(rename = "x5t#S256", alias = "x5t#s256", skip_serializing_if = "Option::is_none")]
    pub x5t_s256: Option<String>,
}
impl JwsHeader {
pub fn new(alg: JwsAlg) -> Self {
Self {
alg,
jku: None,
jwk: None,
typ: None,
cty: None,
kid: None,
x5u: None,
x5c: None,
x5t: None,
x5t_s256: None,
}
}
}
/// A JSON Web Signature: a JOSE header plus a raw payload.
///
/// The payload is arbitrary bytes. [`Jws::encode`] base64url-encodes it as-is, so it
/// is not required to be JSON.
#[derive(Debug, Clone)]
pub struct Jws {
    /// Header, serialized as the first segment of the compact token.
    pub header: JwsHeader,
    /// Payload bytes before base64url encoding (the second token segment).
    pub payload: Vec<u8>,
}
impl Jws {
    /// Creates a JWS whose header contains only `alg`, wrapping the given raw payload.
    pub fn new(alg: JwsAlg, payload: Vec<u8>) -> Self {
        Self {
            header: JwsHeader::new(alg),
            payload,
        }
    }

    /// Verifies the signature of `encoded_token` with `public_key`, using the
    /// algorithm stored in `self.header.alg`.
    ///
    /// This supports a "decode first, verify later" flow:
    /// 1. Decode with [`Jws::decode_without_validation`].
    /// 2. Inspect the header (e.g. `kid`) to select a key.
    /// 3. Call this method with the same token string.
    ///
    /// Only `self.header.alg` is read from `self`. The header and payload segments of
    /// `encoded_token` are **not** compared against `self.header` / `self.payload`,
    /// so callers must pass the exact token this `Jws` was decoded from.
    ///
    /// # Errors
    ///
    /// - [`JwsError::InvalidEncoding`] if `encoded_token` does not consist of exactly
    ///   three dot-separated segments with non-empty header and signature segments.
    /// - [`JwsError::Base64Decoding`] if the signature segment is not unpadded base64url.
    /// - [`JwsError::Signature`] if `self.header.alg` has no supported signature algorithm.
    /// - Any error returned by the verifier, converted via `From`. An invalid
    ///   signature surfaces as `"signature error: invalid signature"`.
    pub fn check_signature(&self, encoded_token: &str, public_key: &PublicKey) -> Result<(), JwsError> {
        // Apply the same structural rules as decoding: exactly `header.payload.signature`.
        // The payload may be empty; header and signature may not.
        let mut segments = encoded_token.split('.');
        let (header_b64, payload_b64, signature_b64) =
            match (segments.next(), segments.next(), segments.next(), segments.next()) {
                (Some(header), Some(payload), Some(signature), None)
                    if !header.is_empty() && !signature.is_empty() =>
                {
                    (header, payload, signature)
                }
                _ => {
                    return Err(JwsError::InvalidEncoding {
                        input: encoded_token.to_owned(),
                    })
                }
            };

        // The JWS signing input is `BASE64URL(header) '.' BASE64URL(payload)`,
        // i.e. everything before the last dot.
        let signing_input = &encoded_token[..header_b64.len() + 1 + payload_b64.len()];
        let signature = base64::decode_config(signature_b64, base64::URL_SAFE_NO_PAD)?;
        let signature_algo = SignatureAlgorithm::try_from(self.header.alg)?;
        signature_algo.verify(public_key, signing_input.as_bytes(), &signature)?;
        Ok(())
    }

    /// Serializes and signs this JWS into compact form `header.payload.signature`.
    ///
    /// The header is serialized with `serde_json`. Header, payload and signature are
    /// each base64url-encoded without padding. The signature is computed over the
    /// ASCII string `<header_b64>.<payload_b64>` with the algorithm in
    /// `self.header.alg`.
    ///
    /// # Errors
    ///
    /// - [`JwsError::Json`] if the header cannot be serialized.
    /// - [`JwsError::Signature`] if `self.header.alg` has no supported signature algorithm.
    /// - Any error returned by the signer, converted via `From`.
    pub fn encode(&self, private_key: &PrivateKey) -> Result<String, JwsError> {
        let header_base64 = base64::encode_config(&serde_json::to_vec(&self.header)?, base64::URL_SAFE_NO_PAD);
        let payload_base64 = base64::encode_config(&self.payload, base64::URL_SAFE_NO_PAD);
        let header_and_payload = [header_base64, payload_base64].join(".");
        let signature_algo = SignatureAlgorithm::try_from(self.header.alg)?;
        let signature = signature_algo.sign(header_and_payload.as_bytes(), private_key)?;
        let signature_base64 = base64::encode_config(&signature, base64::URL_SAFE_NO_PAD);
        Ok([header_and_payload, signature_base64].join("."))
    }

    /// Decodes a compact-serialized token and verifies its signature with `public_key`.
    ///
    /// The verification algorithm is taken from the token's own `alg` header.
    ///
    /// # Errors
    ///
    /// Returns a [`JwsError`] if the token is malformed, its header is not valid JSON,
    /// a segment is not valid base64url, or the signature does not verify.
    pub fn decode(encoded_token: &str, public_key: &PublicKey) -> Result<Self, JwsError> {
        decode_impl(encoded_token, Some(public_key))
    }

    /// Decodes a compact-serialized token **without** checking its signature.
    ///
    /// The result must not be trusted until [`Jws::check_signature`] succeeds for the
    /// same token string.
    ///
    /// # Errors
    ///
    /// Returns a [`JwsError`] if the token is malformed, its header is not valid JSON,
    /// or a segment is not valid base64url.
    pub fn decode_without_validation(encoded_token: &str) -> Result<Self, JwsError> {
        decode_impl(encoded_token, None)
    }
}
/// Shared implementation of [`Jws::decode`] and [`Jws::decode_without_validation`].
///
/// Parses a JWS compact serialization `header.payload.signature`, where each segment
/// is unpadded base64url.
///
/// The header is decoded and parsed first. When `public_key` is `Some`, the
/// signature is then verified over `header.payload` using the algorithm named in
/// the token's own header. Only after that is the payload decoded.
///
/// # Errors
///
/// - [`JwsError::InvalidEncoding`] if the token does not have exactly three
///   dot-separated segments, or if its header or signature segment is empty.
/// - [`JwsError::Base64Decoding`] if a segment is not valid unpadded base64url.
/// - [`JwsError::Json`] if the header is not a valid [`JwsHeader`] JSON object.
/// - [`JwsError::Signature`] (or another verifier error) if the algorithm is
///   unsupported or the signature does not verify.
fn decode_impl(encoded_token: &str, public_key: Option<&PublicKey>) -> Result<Jws, JwsError> {
    // RFC 7515 §7.1: the compact form has exactly three segments. Rejecting extra
    // segments here reports `InvalidEncoding` instead of a misleading base64 or
    // signature error. The payload may legitimately be empty; header and signature
    // may not.
    let mut segments = encoded_token.split('.');
    let (header_b64, payload_b64, signature_b64) =
        match (segments.next(), segments.next(), segments.next(), segments.next()) {
            (Some(header), Some(payload), Some(signature), None)
                if !header.is_empty() && !signature.is_empty() =>
            {
                (header, payload, signature)
            }
            _ => {
                return Err(JwsError::InvalidEncoding {
                    input: encoded_token.to_owned(),
                })
            }
        };

    let header_json = base64::decode_config(header_b64, base64::URL_SAFE_NO_PAD)?;
    let header = serde_json::from_slice::<JwsHeader>(&header_json)?;

    if let Some(public_key) = public_key {
        // Signing input: everything before the last dot, i.e. `header_b64.payload_b64`.
        let signing_input = &encoded_token[..header_b64.len() + 1 + payload_b64.len()];
        let signature = base64::decode_config(signature_b64, base64::URL_SAFE_NO_PAD)?;
        let signature_algo = SignatureAlgorithm::try_from(header.alg)?;
        signature_algo.verify(public_key, signing_input.as_bytes(), &signature)?;
    }

    let payload = base64::decode_config(payload_b64, base64::URL_SAFE_NO_PAD)?;
    Ok(Jws { header, payload })
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::pem::Pem;

    /// Payload carried by the reference token `crate::test_files::JOSE_JWT_SIG_EXAMPLE`.
    const PAYLOAD: &str = r#"{"sub":"1234567890","name":"John Doe","admin":true,"iat":1516239022}"#;

    /// Key fixture `RSA_2048_PK_1`. Signing `PAYLOAD` with it reproduces
    /// `JOSE_JWT_SIG_EXAMPLE`.
    fn get_private_key_1() -> PrivateKey {
        let pk_pem = crate::test_files::RSA_2048_PK_1.parse::<Pem>().unwrap();
        PrivateKey::from_pem(&pk_pem).unwrap()
    }

    /// Key fixture `RSA_2048_PK_7`, used as the "wrong key" to provoke verification
    /// failures.
    fn get_private_key_2() -> PrivateKey {
        let pk_pem = crate::test_files::RSA_2048_PK_7.parse::<Pem>().unwrap();
        PrivateKey::from_pem(&pk_pem).unwrap()
    }

    #[test]
    fn encode_rsa_sha256() {
        let jwt = Jws {
            header: JwsHeader {
                typ: Some(String::from("JWT")),
                ..JwsHeader::new(JwsAlg::RS256)
            },
            payload: PAYLOAD.as_bytes().to_vec(),
        };
        let encoded = jwt.encode(&get_private_key_1()).unwrap();
        assert_eq!(encoded, crate::test_files::JOSE_JWT_SIG_EXAMPLE);
    }

    #[test]
    fn encode_decode_roundtrip() {
        let private_key = get_private_key_1();
        let jws = Jws::new(JwsAlg::RS256, b"hello".to_vec());
        let encoded = jws.encode(&private_key).unwrap();
        let decoded = Jws::decode(&encoded, &private_key.to_public_key()).unwrap();
        assert_eq!(decoded.header, jws.header);
        assert_eq!(decoded.payload, jws.payload);
    }

    #[test]
    fn encode_unsupported_alg_err() {
        let jws = Jws::new(JwsAlg::HS256, PAYLOAD.as_bytes().to_vec());
        let err = jws.encode(&get_private_key_1()).unwrap_err();
        assert!(matches!(err, JwsError::Signature { .. }));
    }

    #[test]
    fn decode_rsa_sha256() {
        let public_key = get_private_key_1().to_public_key();
        let jwt = Jws::decode(crate::test_files::JOSE_JWT_SIG_EXAMPLE, &public_key).unwrap();
        assert_eq!(jwt.payload.as_slice(), PAYLOAD.as_bytes());
    }

    #[test]
    fn decode_rsa_sha256_delayed_signature_check() {
        let jws = Jws::decode_without_validation(crate::test_files::JOSE_JWT_SIG_EXAMPLE).unwrap();
        assert_eq!(jws.payload.as_slice(), PAYLOAD.as_bytes());
        let public_key = get_private_key_2().to_public_key();
        let err = jws
            .check_signature(crate::test_files::JOSE_JWT_SIG_EXAMPLE, &public_key)
            .unwrap_err();
        assert_eq!(err.to_string(), "signature error: invalid signature");
    }

    #[test]
    fn decode_rsa_sha256_delayed_signature_check_valid_key() {
        let jws = Jws::decode_without_validation(crate::test_files::JOSE_JWT_SIG_EXAMPLE).unwrap();
        let public_key = get_private_key_1().to_public_key();
        jws.check_signature(crate::test_files::JOSE_JWT_SIG_EXAMPLE, &public_key)
            .unwrap();
    }

    #[test]
    fn decode_rsa_sha256_invalid_signature_err() {
        let public_key = get_private_key_2().to_public_key();
        let err = Jws::decode(crate::test_files::JOSE_JWT_SIG_EXAMPLE, &public_key).unwrap_err();
        assert_eq!(err.to_string(), "signature error: invalid signature");
    }

    #[test]
    fn decode_invalid_base64_err() {
        let public_key = get_private_key_1().to_public_key();
        let err = Jws::decode("aieoè~†.tésp.à", &public_key).unwrap_err();
        assert_eq!(err.to_string(), "couldn't decode base64: Invalid byte 195, offset 4.");
    }

    #[test]
    fn decode_invalid_json_err() {
        let public_key = get_private_key_1().to_public_key();
        let err = Jws::decode("abc.abc.abc", &public_key).unwrap_err();
        assert_eq!(err.to_string(), "JSON error: expected value at line 1 column 1");
        let err = Jws::decode("eyAiYWxnIjogIkhTMjU2IH0K.abc.abc", &public_key).unwrap_err();
        assert_eq!(
            err.to_string(),
            "JSON error: control character (\\u0000-\\u001F) \
             found while parsing a string at line 2 column 0"
        );
    }

    #[test]
    fn decode_invalid_encoding_err() {
        let public_key = get_private_key_1().to_public_key();
        let err = Jws::decode(".abc.abc", &public_key).unwrap_err();
        assert_eq!(err.to_string(), "input isn't a valid token string: .abc.abc");
        let err = Jws::decode("abc.abc.", &public_key).unwrap_err();
        assert_eq!(err.to_string(), "input isn't a valid token string: abc.abc.");
        let err = Jws::decode("abc.abc", &public_key).unwrap_err();
        assert_eq!(err.to_string(), "input isn't a valid token string: abc.abc");
        let err = Jws::decode("abc", &public_key).unwrap_err();
        assert_eq!(err.to_string(), "input isn't a valid token string: abc");
    }
}