use std::fmt;
use base64::{Engine, prelude::BASE64_URL_SAFE_NO_PAD};
use serde::{Deserialize, Serialize};
use crate::{Header, Payload, SignatureVerifier, Validator};
use super::Error;
/// A compact-serialized JWT held as text together with the positions of its two
/// `.` separators.
///
/// Only the separator offsets are stored; the header, payload and signature
/// segments are sliced out of `raw` on demand and are not decoded or verified
/// by this type itself.
#[derive(Clone, Copy)]
pub struct RawJwt<S>
where
    S: AsRef<str>,
{
    /// The complete token text.
    raw: S,
    /// Byte offset of the first `.` in `raw`.
    dot0: u32,
    /// Byte offset of the second `.` in `raw`.
    dot1: u32,
}
impl<S: AsRef<str>> RawJwt<S> {
pub fn new(raw: S) -> Result<Self, Error> {
let mut dots = raw.as_ref().match_indices('.').map(|(pos, _)| pos);
let dot0 = dots.next().ok_or(Error::InvalidSyntax)?.try_into().unwrap();
let dot1 = dots.next().ok_or(Error::InvalidSyntax)?.try_into().unwrap();
if dots.next().is_some() {
return Err(Error::InvalidSyntax);
}
Ok(Self { raw, dot0, dot1 })
}
pub fn raw_header(&self) -> &str {
&self.raw.as_ref()[..self.dot0 as usize]
}
pub fn raw_payload(&self) -> &str {
&self.raw.as_ref()[(self.dot0 + 1) as usize..self.dot1 as usize]
}
pub fn raw_signature(&self) -> &str {
&self.raw.as_ref()[(self.dot1 + 1) as usize..]
}
pub fn raw_message(&self) -> &str {
&self.raw.as_ref()[..self.dot1 as usize]
}
pub fn decode_header(&self) -> Result<String, Error> {
let bytes = BASE64_URL_SAFE_NO_PAD.decode(self.raw_header())
.map_err(|_| Error::InvalidSyntax)?;
String::from_utf8(bytes).map_err(|_| Error::InvalidUtf8)
}
pub fn dangerously_unchecked_decode_payload(&self) -> Result<String, Error> {
let bytes = BASE64_URL_SAFE_NO_PAD.decode(self.raw_payload())
.map_err(|_| Error::InvalidSyntax)?;
String::from_utf8(bytes).map_err(|_| Error::InvalidUtf8)
}
pub fn decode_signature(&self) -> Result<Vec<u8>, Error> {
BASE64_URL_SAFE_NO_PAD.decode(self.raw_signature()).map_err(|_| Error::InvalidSyntax)
}
pub async fn decode<H, P, R>(
&self,
signature_verifier: &(impl ?Sized + SignatureVerifier<H>),
validator: &impl Validator<H, P>,
callback: impl FnOnce(Header<H>, Payload<P>) -> R,
) -> Result<R, Error>
where
H: for<'de> Deserialize<'de>,
P: for<'de> Deserialize<'de>,
{
let signature = self.decode_signature()?;
let header_str = self.decode_header()?;
let header = <Header<H>>::from_str(&header_str)?;
signature_verifier.verify(&header, self.raw_message(), &signature).await?;
let payload_str = self.dangerously_unchecked_decode_payload()?;
let payload = <Payload<P>>::from_str(&payload_str)?;
validator.validate(&header, &payload)?;
let out = callback(header, payload);
Ok(out)
}
pub fn dot0(&self) -> u32 {
self.dot0
}
pub fn dot1(&self) -> u32 {
self.dot1
}
pub fn full_raw_string(&self) -> &str {
self.raw.as_ref()
}
pub fn into_inner(self) -> S {
self.raw
}
}
impl<'a> RawJwt<&'a str> {
    /// Copies the borrowed token text into an owned `RawJwt<String>`.
    ///
    /// The already-computed separator offsets are carried over unchanged, so the
    /// token is not re-parsed.
    pub fn to_owned(&self) -> RawJwt<String> {
        let Self { raw, dot0, dot1 } = *self;
        RawJwt {
            raw: String::from(raw),
            dot0,
            dot1,
        }
    }
}
/// Borrows the complete, still-encoded token text.
impl<S> AsRef<str> for RawJwt<S>
where
    S: AsRef<str>,
{
    fn as_ref(&self) -> &str {
        self.full_raw_string()
    }
}
impl TryFrom<String> for RawJwt<String> {
type Error = Error;
fn try_from(value: String) -> Result<Self, Self::Error> {
Self::new(value)
}
}
impl<'a> TryFrom<&'a str> for RawJwt<&'a str> {
type Error = Error;
fn try_from(value: &'a str) -> Result<Self, Self::Error> {
Self::new(value)
}
}
impl<'de, S: 'de + AsRef<str> + Deserialize<'de>> Deserialize<'de> for RawJwt<S> {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: serde::Deserializer<'de>,
{
let s = S::deserialize(deserializer)?;
Self::new(s).map_err(|_| <D::Error as serde::de::Error>::custom("invalid JWT"))
}
}
impl<S: AsRef<str> + Serialize> Serialize for RawJwt<S> {
fn serialize<R>(&self, serializer: R) -> Result<R::Ok, R::Error>
where
R: serde::Serializer,
{
self.raw.serialize(serializer)
}
}
/// Writes the full token text by delegating to `str`'s `Display`, so width,
/// fill and precision flags behave as they do for a plain string.
impl<S: AsRef<str>> fmt::Display for RawJwt<S> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let text: &str = self.raw.as_ref();
        fmt::Display::fmt(text, f)
    }
}
/// Debug-formats the full token text via `str`'s `Debug`, i.e. as a quoted,
/// escaped string.
///
/// NOTE(review): this includes the signature segment verbatim. If `RawJwt`
/// values can reach logs, consider whether the token should be redacted here.
impl<S: AsRef<str>> fmt::Debug for RawJwt<S> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let text: &str = self.raw.as_ref();
        fmt::Debug::fmt(text, f)
    }
}
#[cfg(test)]
mod tests {
    use crate::RawJwt;

    /// A real EdDSA-signed token, split and decoded segment by segment.
    #[test]
    fn simple() {
        let jwt = "eyJhbGciOiJFZERTQSIsInR5cCI6IkpXVCJ9.\
            eyJleHAiOjE3NjE4MzcyNDF9.\
            b5IRG8bphWWlH9OBycX9A4i9wHOrt6ypYkb3b8IxM1gqVvjKU8RCrn-OXyHUCsfsn3FOHpnoFBzLm0WUKFztDA";
        let raw = RawJwt::new(jwt).unwrap();
        assert_eq!(raw.raw_header(), "eyJhbGciOiJFZERTQSIsInR5cCI6IkpXVCJ9");
        assert_eq!(raw.raw_payload(), "eyJleHAiOjE3NjE4MzcyNDF9");
        assert_eq!(
            raw.raw_message(),
            "eyJhbGciOiJFZERTQSIsInR5cCI6IkpXVCJ9.eyJleHAiOjE3NjE4MzcyNDF9"
        );
        assert_eq!(raw.full_raw_string(), jwt);
        assert_eq!(
            raw.raw_signature(),
            "b5IRG8bphWWlH9OBycX9A4i9wHOrt6ypYkb3b8IxM1gqVvjKU8RCrn-OXyHUCsfsn3FOHpnoFBzLm0WUKFztDA"
        );
        assert_eq!(raw.decode_header().unwrap(), r#"{"alg":"EdDSA","typ":"JWT"}"#);
        assert_eq!(
            raw.decode_signature().unwrap(),
            vec![
                0x6f, 0x92, 0x11, 0x1b, 0xc6, 0xe9, 0x85, 0x65,
                0xa5, 0x1f, 0xd3, 0x81, 0xc9, 0xc5, 0xfd, 0x03,
                0x88, 0xbd, 0xc0, 0x73, 0xab, 0xb7, 0xac, 0xa9,
                0x62, 0x46, 0xf7, 0x6f, 0xc2, 0x31, 0x33, 0x58,
                0x2a, 0x56, 0xf8, 0xca, 0x53, 0xc4, 0x42, 0xae,
                0x7f, 0x8e, 0x5f, 0x21, 0xd4, 0x0a, 0xc7, 0xec,
                0x9f, 0x71, 0x4e, 0x1e, 0x99, 0xe8, 0x14, 0x1c,
                0xcb, 0x9b, 0x45, 0x94, 0x28, 0x5c, 0xed, 0x0c,
            ]
        );
    }

    /// A token must contain exactly two `.` separators; anything else is rejected.
    #[test]
    fn rejects_wrong_number_of_dots() {
        for bad in ["", "abc", "a.b", "a.b.c.d", "...."] {
            assert!(RawJwt::new(bad).is_err(), "accepted {:?}", bad);
        }
    }

    /// Empty segments are structurally valid; only decoding can reject them.
    #[test]
    fn empty_segments() {
        let raw = RawJwt::new("..").unwrap();
        assert_eq!(raw.raw_header(), "");
        assert_eq!(raw.raw_payload(), "");
        assert_eq!(raw.raw_signature(), "");
        assert_eq!(raw.raw_message(), ".");
        assert_eq!((raw.dot0(), raw.dot1()), (0, 1));
    }

    /// An owned copy keeps the same text and offsets; `Display` shows the full token.
    #[test]
    fn to_owned_and_display() {
        let borrowed = RawJwt::new("h.p.s").unwrap();
        let owned = borrowed.to_owned();
        assert_eq!(owned.full_raw_string(), "h.p.s");
        assert_eq!((owned.dot0(), owned.dot1()), (1, 3));
        assert_eq!(owned.raw_payload(), "p");
        assert_eq!(owned.raw_signature(), "s");
        assert_eq!(format!("{}", owned), "h.p.s");
    }
}