use anyhow::{bail, Context, Result};
use base64::{read::DecoderReader, write::EncoderWriter};
use serde_json::{Map, Value};
use std::io::{Read, Write};
use crate::Verify;
/// Decoded JWS protected header, kept as the raw JSON object (no schema imposed).
type JwsHeader = Map<String, Value>;
/// Separator between the segments of a compact-serialized JWS.
///
/// A `const` rather than a `static`: it is a plain value that never needs a fixed address.
const DOT_BYTE: u8 = b'.';
/// Verifies a compact JWS whose payload is detached, choosing the verifier from its header.
///
/// `jws` is the `header.payload.signature` token (the payload segment is skipped). The
/// payload bytes are read from `payload` until EOF and streamed into the verifier.
/// `selector` is invoked once with the decoded protected header, which has NOT been
/// authenticated yet. It returns the verifier to use, or `None` to reject the token.
///
/// Returns the protected header once the signature has been checked successfully.
///
/// # Errors
///
/// Fails when `jws` is malformed, when `selector` returns `None`, on I/O errors while
/// reading `payload` or feeding the verifier, and when the signature does not match.
pub fn deserialize_selector<F, V>(
    jws: &impl AsRef<[u8]>,
    payload: &mut impl Read,
    selector: F,
) -> Result<JwsHeader>
where
    // `FnOnce` is all `DeserializeJwsWriter::new` requires. It lets callers move an owned
    // verifier out of the closure, and every `Fn` closure still satisfies it.
    F: FnOnce(&JwsHeader) -> Option<V>,
    V: Verify,
{
    let mut writer = DeserializeJwsWriter::new(jws, selector)?;
    std::io::copy(payload, &mut writer)?;
    writer.finish()
}
/// Verifies a compact JWS with a detached payload using one fixed verifier.
///
/// Convenience wrapper for callers that never need to inspect the protected header
/// to pick a verifier. The payload is read from `payload` until EOF. On success the
/// decoded protected header is returned.
pub fn deserialize<V>(
    jws: &impl AsRef<[u8]>,
    payload: &mut impl Read,
    verifier: V,
) -> Result<JwsHeader>
where
    V: Verify,
{
    // The header is ignored: the same verifier is handed out unconditionally.
    let always_this_verifier = move |_header: &JwsHeader| Some(verifier);
    let mut sink = DeserializeJwsWriter::new(jws, always_this_verifier)?;
    let _copied: u64 = std::io::copy(payload, &mut sink)?;
    sink.finish()
}
/// Streaming verifier for a compact JWS whose payload is supplied separately.
///
/// Payload bytes written through [`Write`] are base64url-encoded (no padding) and
/// forwarded to the wrapped verifier. Call `finish` afterwards to check the signature.
pub struct DeserializeJwsWriter<V>
where
    V: Write,
{
    /// base64url encoder wrapping the verifier that accumulates the signing input.
    encoder: EncoderWriter<V>,
    /// Decoded protected header; `finish` takes it out, after which it is `None`.
    header: Option<JwsHeader>,
    /// Signature bytes, already base64url-decoded from the third JWS segment.
    signature: Vec<u8>,
}
impl<V> DeserializeJwsWriter<V>
where
    V: Verify,
{
    /// Parses a compact JWS with a detached payload and prepares a writer for the payload.
    ///
    /// `jws` must consist of exactly three `.`-separated segments:
    /// `header.payload.signature`. The payload segment is skipped. The payload bytes are
    /// instead written into the returned writer, which base64url-encodes them into the
    /// verifier.
    ///
    /// `selector` receives the decoded protected header, which is not authenticated yet.
    /// It returns the verifier to use, or `None` to reject the token.
    ///
    /// # Errors
    ///
    /// Fails when `jws` does not have exactly three segments, when the header is not
    /// base64url-encoded JSON, when the signature is not valid base64url, when `selector`
    /// returns `None`, or when writing the signing-input prefix to the verifier fails.
    pub fn new<S>(jws: &impl AsRef<[u8]>, selector: S) -> Result<Self>
    where
        S: FnOnce(&JwsHeader) -> Option<V>,
    {
        let mut segments = jws.as_ref().split(|&byte| byte == DOT_BYTE);

        let encoded_header = segments.next().context("wrong jws format")?;
        let header: JwsHeader = {
            let mut reader = encoded_header;
            let decoder = DecoderReader::new(&mut reader, base64::URL_SAFE_NO_PAD);
            serde_json::from_reader(decoder).context("wrong jws header format")?
        };

        // NOTE(review): the payload segment's content is ignored (it is not rejected when
        // non-empty); the payload that gets verified is the one written into this writer.
        let _payload_segment = segments.next().context("wrong jws format")?;

        let encoded_signature = segments.next().context("wrong jws format")?;
        // A compact JWS has exactly three segments; anything after the signature is malformed.
        if segments.next().is_some() {
            bail!("wrong jws format");
        }
        let signature = base64::decode_config(encoded_signature, base64::URL_SAFE_NO_PAD)
            .context("wrong jws signature format")?;

        let mut verifier = selector(&header).context("verifier is not found")?;
        // The signing input is `BASE64URL(header) '.' BASE64URL(payload)`: feed the first two
        // parts now; the payload is encoded by `encoder` as it is written.
        verifier.write_all(encoded_header)?;
        verifier.write_all(&[DOT_BYTE])?;
        Ok(Self {
            encoder: EncoderWriter::new(verifier, base64::URL_SAFE_NO_PAD),
            header: Some(header),
            signature,
        })
    }

    /// Finalizes the encoded payload and checks the signature.
    ///
    /// On success returns the protected header. Once the encoder has been finalized, the
    /// writer is spent: later calls to `finish` return an error, regardless of whether
    /// verification succeeded.
    ///
    /// # Errors
    ///
    /// Fails if the writer is already spent, if flushing the encoder fails (the header is
    /// then kept so the call is not rejected as a double finish), if the verifier reports
    /// an error, or if the signature does not match.
    pub fn finish(&mut self) -> Result<JwsHeader> {
        if self.header.is_none() {
            bail!("Deserializer has already had finish() called");
        }
        let verifier = self.encoder.finish()?;
        // The encoder is finalized now, so mark the writer as spent *before* verifying.
        // Otherwise a retry after a failed or erroring verification would reach
        // `EncoderWriter::finish` a second time.
        let header = self
            .header
            .take()
            .expect("header presence was checked at the top of finish()");
        if verifier.verify(&self.signature)? {
            Ok(header)
        } else {
            bail!("incorrect signature")
        }
    }
}
impl<V> Write for DeserializeJwsWriter<V>
where
    V: Verify,
{
    /// Base64url-encodes `buf` and forwards it to the verifier.
    ///
    /// Once `finish` has taken the header, the encoder has already been finalized. Writing
    /// then returns an error instead of reaching the finalized encoder.
    fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
        if self.header.is_none() {
            return Err(std::io::Error::new(
                std::io::ErrorKind::Other,
                "DeserializeJwsWriter has already had finish() called",
            ));
        }
        self.encoder.write(buf)
    }

    /// Flushes pending encoded output to the verifier.
    ///
    /// After `finish` has taken the header, everything was already handed over, so this
    /// is a no-op.
    fn flush(&mut self) -> std::io::Result<()> {
        if self.header.is_none() {
            return Ok(());
        }
        self.encoder.flush()
    }
}