use std::collections::BTreeMap;
use base64::Engine;
use crate::{
Error,
JsonObject,
JsonValue,
Result,
Signer,
Verifier,
};
/// Encode a header and payload into the unsigned part of a compact serialization.
///
/// The header is serialized to JSON. Both the JSON header and the raw payload are
/// encoded with URL-safe base64 without padding and joined with a single `.`.
/// No signature is appended; see [`encode_sign`] for that.
///
/// # Panics
///
/// Serializing a map with string keys and JSON values cannot fail, so the
/// `expect` below only fires on a broken invariant.
pub fn encode(header: &JsonObject, payload: &[u8]) -> EncodedMessage {
	let header_json = serde_json::to_vec(header)
		.expect("serializing a JSON object with string keys can not fail");
	let base64 = base64::engine::general_purpose::URL_SAFE_NO_PAD;

	// Room for both encoded parts plus the separating '.'.
	let output_len = base64_len(header_json.len()) + base64_len(payload.len()) + 1;
	let mut data = String::with_capacity(output_len);

	base64.encode_string(&header_json, &mut data);
	// Remember where the header ends so `EncodedMessage::header` can slice it back out.
	let header_length = data.len();
	data.push('.');
	base64.encode_string(payload, &mut data);

	EncodedMessage { data, header_length }
}
/// Encode and sign a message in compact serialization form.
///
/// The signer first gets a chance to add its own parameters to the header
/// (through `Signer::set_header_params`). The header and payload are then
/// encoded with [`encode`]. The MAC is computed over the encoded header and
/// payload, and its URL-safe base64 encoding is appended after a second `.`.
///
/// # Errors
///
/// Returns any error produced by `Signer::compute_mac`.
pub fn encode_sign(mut header: JsonObject, payload: &[u8], signer: &impl Signer) -> Result<EncodedSignedMessage> {
	signer.set_header_params(&mut header);
	let unsigned = encode(&header, payload);

	let mac = signer.compute_mac(unsigned.header().as_bytes(), unsigned.payload().as_bytes())?;
	let (header_length, payload_length) = (unsigned.header().len(), unsigned.payload().len());

	// Reuse the unsigned buffer and grow it once for the '.' and the signature.
	let mut data = unsigned.into_data();
	data.reserve(1 + base64_len(mac.len()));
	data.push('.');
	base64::engine::general_purpose::URL_SAFE_NO_PAD.encode_string(&mac, &mut data);

	Ok(EncodedSignedMessage { data, header_length, payload_length })
}
#[deprecated = "this function was marked unsafe but has no safety implications, use decode_unverified instead"]
pub unsafe fn decode(data: &[u8]) -> Result<(DecodedMessage, Vec<u8>)> {
decode_unverified(data)
}
/// Split and decode a compact message without verifying its signature.
///
/// Returns the decoded message together with the base64-decoded signature
/// bytes. The signature is NOT checked; use [`decode_verify`] when the message
/// must be authenticated.
///
/// # Errors
///
/// Returns an invalid-message error in three cases:
/// - the input is not exactly three `.`-separated parts;
/// - a part is not valid unpadded URL-safe base64;
/// - the header is not a JSON object.
pub fn decode_unverified(data: &[u8]) -> Result<(DecodedMessage, Vec<u8>)> {
	let parts = split_encoded_parts(data)?;
	parts.decode()
}
/// Split, decode and verify a compact message.
///
/// The verifier is given:
/// - the decoded header;
/// - the still-encoded header and payload parts (the bytes that were signed);
/// - the base64-decoded signature.
///
/// # Errors
///
/// Returns an error in two cases:
/// - the message cannot be split or decoded (see [`decode_unverified`]);
/// - `Verifier::verify` returns an error.
pub fn decode_verify(data: &[u8], verifier: &impl Verifier) -> Result<DecodedMessage> {
	let encoded = split_encoded_parts(data)?;
	let (decoded, signature) = encoded.decode()?;
	verifier.verify(Some(&decoded.header), None, encoded.header, encoded.payload, &signature)?;
	Ok(decoded)
}
/// A decoded message: the parsed header and the raw payload bytes.
#[derive(Clone, Debug, PartialEq)]
pub struct DecodedMessage {
	/// The decoded header, parsed as a JSON object.
	pub header: JsonObject,

	/// The base64-decoded payload bytes.
	pub payload: Vec<u8>,
}
/// An unsigned encoded message: `base64url(header) "." base64url(payload)`.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct EncodedMessage {
	/// The complete encoded text.
	data: String,

	/// Length in bytes of the encoded header at the start of `data`.
	header_length: usize,
}
/// A signed encoded message: `header "." payload "." signature`, each part base64url encoded.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct EncodedSignedMessage {
	/// The complete encoded text.
	data: String,

	/// Length in bytes of the encoded header at the start of `data`.
	header_length: usize,

	/// Length in bytes of the encoded payload, which follows the first `.`.
	payload_length: usize,
}
impl DecodedMessage {
	/// Create a decoded message from a header and a payload.
	pub fn new(header: impl Into<JsonObject>, payload: impl Into<Vec<u8>>) -> Self {
		let header = header.into();
		let payload = payload.into();
		Self { header, payload }
	}

	/// Decode a message from its base64url encoded header and payload parts.
	///
	/// # Errors
	///
	/// Returns an invalid-message error in two cases:
	/// - a part is not valid base64 for the `URL_SAFE_NO_PAD` engine;
	/// - the decoded header is not a JSON object.
	pub fn from_encoded_parts(header: &[u8], payload: &[u8]) -> Result<Self> {
		let header_bytes = decode_base64_url(header, "header")?;
		let payload = decode_base64_url(payload, "payload")?;
		let header: BTreeMap<String, JsonValue> = decode_json(&header_bytes, "header")?;
		Ok(Self { header, payload })
	}

	/// Parse the payload as JSON into any deserializable type.
	pub fn parse_json<'de, T: serde::de::Deserialize<'de> + 'de>(&'de self) -> std::result::Result<T, serde_json::Error> {
		serde_json::from_slice(self.payload.as_slice())
	}

	/// Parse the payload as an arbitrary JSON value.
	pub fn parse_json_value(&self) -> std::result::Result<JsonValue, serde_json::Error> {
		self.parse_json::<JsonValue>()
	}

	/// Parse the payload as a JSON object.
	pub fn parse_json_object(&self) -> std::result::Result<JsonObject, serde_json::Error> {
		self.parse_json::<JsonObject>()
	}
}
impl EncodedMessage {
	/// The complete encoded message as text.
	pub fn data(&self) -> &str {
		self.data.as_str()
	}

	/// Consume the message and return the encoded text.
	pub fn into_data(self) -> String {
		let Self { data, .. } = self;
		data
	}

	/// The complete encoded message as bytes.
	pub fn as_bytes(&self) -> &[u8] {
		self.data.as_bytes()
	}

	/// The base64url encoded header part (everything before the `.`).
	pub fn header(&self) -> &str {
		let (header, _) = self.data.split_at(self.header_length);
		header
	}

	/// The base64url encoded payload part (everything after the `.`).
	pub fn payload(&self) -> &str {
		let start = self.header_length + 1;
		&self.data[start..]
	}
}
impl EncodedSignedMessage {
	/// The complete encoded message as text: `header "." payload "." signature`.
	pub fn data(&self) -> &str {
		self.data.as_str()
	}

	/// Consume the message and return the encoded text.
	pub fn into_data(self) -> String {
		let Self { data, .. } = self;
		data
	}

	/// The complete encoded message as bytes.
	pub fn as_bytes(&self) -> &[u8] {
		self.data.as_bytes()
	}

	/// The base64url encoded header part.
	pub fn header(&self) -> &str {
		let end = self.header_length;
		&self.data[..end]
	}

	/// The base64url encoded payload part.
	pub fn payload(&self) -> &str {
		let (start, end) = (self.payload_start(), self.payload_end());
		&self.data[start..end]
	}

	/// The base64url encoded signature part.
	pub fn signature(&self) -> &str {
		let start = self.signature_start();
		&self.data[start..]
	}

	/// Borrow all three encoded parts as byte slices.
	pub fn parts(&self) -> CompactSerializedParts<'_> {
		let header = self.header().as_bytes();
		let payload = self.payload().as_bytes();
		let signature = self.signature().as_bytes();
		CompactSerializedParts { header, payload, signature }
	}

	/// Byte offset where the payload begins: just past the header and its `.`.
	fn payload_start(&self) -> usize {
		1 + self.header_length
	}

	/// Byte offset one past the last byte of the payload.
	fn payload_end(&self) -> usize {
		self.payload_length + self.payload_start()
	}

	/// Byte offset where the signature begins: just past the payload and its `.`.
	fn signature_start(&self) -> usize {
		1 + self.payload_end()
	}
}
/// The three `.`-separated parts of a compact serialized message, still base64url encoded.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub struct CompactSerializedParts<'a> {
	/// The encoded header.
	pub header: &'a [u8],

	/// The encoded payload.
	pub payload: &'a [u8],

	/// The encoded signature.
	pub signature: &'a [u8],
}
impl CompactSerializedParts<'_> {
	/// Decode the header, payload and signature.
	///
	/// Returns the decoded message and the base64-decoded signature bytes.
	/// The signature is not verified.
	///
	/// # Errors
	///
	/// Returns an invalid-message error in two cases:
	/// - a part is not valid base64;
	/// - the header is not a JSON object.
	pub fn decode(&self) -> Result<(DecodedMessage, Vec<u8>)> {
		let message = DecodedMessage::from_encoded_parts(self.header, self.payload)?;
		decode_base64_url(self.signature, "signature").map(|signature| (message, signature))
	}
}
pub fn split_encoded_parts(data: &[u8]) -> Result<CompactSerializedParts> {
let mut parts = data.splitn(4, |&c| c == b'.');
let header = parts.next().ok_or_else(|| Error::invalid_message("encoded message does not contain a header"))?;
let payload = parts.next().ok_or_else(|| Error::invalid_message("encoded message does not contain a payload"))?;
let signature = parts.next().ok_or_else(|| Error::invalid_message("encoded message does not contain a signature"))?;
if parts.next().is_some() {
return Err(Error::invalid_message("encoded message contains an additional field after the signature"));
}
Ok(CompactSerializedParts{header, payload, signature})
}
/// Compute the length of the unpadded base64 encoding of `input_len` bytes.
///
/// Each full 3-byte group encodes to 4 characters. A trailing group of 1 or 2
/// bytes encodes to 2 or 3 characters, and no padding is added.
///
/// Full groups and the remainder are handled separately to avoid the overflow
/// of the naive `(input_len * 4 + 2) / 3`. That form overflows once `input_len`
/// exceeds `usize::MAX / 4`, which on 32-bit targets is only about 1 GiB.
fn base64_len(input_len: usize) -> usize {
	let full_groups = input_len / 3;
	let remainder = input_len % 3;
	// remainder is 0, 1 or 2 and maps to 0, 2 or 3 output characters.
	full_groups * 4 + (remainder * 4 + 2) / 3
}
fn decode_base64_url(value: &[u8], field_name: &str) -> Result<Vec<u8>> {
let base64 = base64::engine::general_purpose::URL_SAFE_NO_PAD;
match base64.decode(value) {
Ok(x) => Ok(x),
Err(_) => Err(Error::invalid_message(format!("invalid base64 in {}", field_name)))
}
}
fn decode_json<'a, T: serde::Deserialize<'a>>(value: &'a [u8], field_name: &str) -> Result<T> {
match serde_json::from_slice(value) {
Ok(x) => Ok(x),
Err(_) => Err(Error::invalid_message(format!("invalid JSON in {}", field_name)))
}
}
#[cfg(test)]
mod test {
	use super::*;
	use crate::json_object;
	use assert2::assert;

	/// Assert that `source` splits into exactly the given header, payload and signature.
	fn test_split_valid(source: &[u8], header: &[u8], payload: &[u8], signature: &[u8]) {
		let parts = split_encoded_parts(source).unwrap();
		assert!(parts.header == header);
		assert!(parts.payload == payload);
		assert!(parts.signature == signature);
	}

	#[test]
	fn test_split_encoded_parts() {
		// Every combination of empty and non-empty parts must split correctly.
		test_split_valid(b"..", b"", b"", b"");
		test_split_valid(b"..mies", b"", b"", b"mies");
		test_split_valid(b".noot.", b"", b"noot", b"");
		test_split_valid(b".noot.mies", b"", b"noot", b"mies");
		test_split_valid(b"aap..", b"aap", b"", b"");
		test_split_valid(b"aap..mies", b"aap", b"", b"mies");
		test_split_valid(b"aap.noot.", b"aap", b"noot", b"");
		test_split_valid(b"aap.noot.mies", b"aap", b"noot", b"mies");
		// Too few or too many separators are rejected.
		assert!(let Err(Error { kind: Error::InvalidMessage, .. }) = split_encoded_parts(b""));
		assert!(let Err(Error { kind: Error::InvalidMessage, .. }) = split_encoded_parts(b"aapnootmies"));
		assert!(let Err(Error { kind: Error::InvalidMessage, .. }) = split_encoded_parts(b"aap.nootmies"));
		assert!(let Err(Error { kind: Error::InvalidMessage, .. }) = split_encoded_parts(b"aap.noot.mies."));
	}

	// Example JWS from RFC 7515 appendix A.1.
	const RFC7515_A1_ENCODED : &[u8] = b"eyJ0eXAiOiJKV1QiLA0KICJhbGciOiJIUzI1NiJ9.eyJpc3MiOiJqb2UiLA0KICJleHAiOjEzMDA4MTkzODAsDQogImh0dHA6Ly9leGFtcGxlLmNvbS9pc19yb290Ijp0cnVlfQ.dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk";
	// Same message with the payload altered ("joe" -> "jse") but the original signature kept.
	const RFC7515_A1_ENCODED_MANGLED : &[u8] = b"eyJ0eXAiOiJKV1QiLA0KICJhbGciOiJIUzI1NiJ9.eyJpc3MiOiJqc2UiLA0KICJleHAiOjEzMDA4MTkzODAsDQogImh0dHA6Ly9leGFtcGxlLmNvbS9pc19yb290Ijp0cnVlfQ.dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk";
	// Decoded signature bytes of the RFC 7515 A.1 example.
	const RFC7515_A1_SIGNATURE : &[u8] = &[116, 24, 223, 180, 151, 153, 224, 37, 79, 250, 96, 125, 216, 173, 187, 186, 22, 212, 37, 77, 105, 214, 191, 240, 91, 88, 5, 88, 83, 132, 141, 121];

	#[test]
	fn test_decode() {
		let (message, signature) = split_encoded_parts(RFC7515_A1_ENCODED).unwrap().decode().unwrap();
		assert!(&message.header == &json_object!{
			"alg": "HS256",
			"typ": "JWT"
		});
		assert!(let Ok(_) = message.parse_json_object());
		assert!(message.parse_json_object().ok() == Some(json_object!{
			"iss": "joe",
			"exp": 1300819380,
			"http://example.com/is_root": true,
		}));
		assert!(&signature[..] == RFC7515_A1_SIGNATURE);
	}

	#[test]
	fn test_decode_mangled() {
		// Decoding does not verify, so the mangled payload still decodes.
		let (message, signature) = split_encoded_parts(RFC7515_A1_ENCODED_MANGLED).unwrap().decode().unwrap();
		assert!(&message.header == &json_object!{
			"alg": "HS256",
			"typ": "JWT",
		});
		assert!(message.parse_json_object().unwrap() == json_object!{
			"iss": "jse",
			"exp": 1300819380,
			"http://example.com/is_root": true,
		});
		assert!(&signature[..] == RFC7515_A1_SIGNATURE);
	}

	#[test]
	fn test_encode() {
		let header = json_object!{"typ": "JWT", "alg": "HS256"};
		let encoded = encode(&header, b"foo");
		assert!(encoded.header() == "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9");
		assert!(encoded.payload() == "Zm9v");
		assert!(encoded.data() == "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.Zm9v")
	}

	#[test]
	fn test_encode_decode_round_trip() {
		// Encoding and then decoding the parts must give back the original header and payload.
		let header = json_object!{"typ": "JWT", "alg": "HS256"};
		let encoded = encode(&header, b"foo");
		let message = DecodedMessage::from_encoded_parts(encoded.header().as_bytes(), encoded.payload().as_bytes()).unwrap();
		assert!(&message.header == &header);
		assert!(message.payload == b"foo".to_vec());
	}
}