1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
#[cfg(test)]
mod test;

pub mod error;
pub mod util;

use base64::{self, Engine};
use error::JWTError;
use error_stack::{IntoReport, Report, ResultExt};
use serde::{Deserialize, Serialize};
use serde_json::{from_slice, to_string, Value};
use std::collections::BTreeMap;
use time::{serde::timestamp, OffsetDateTime};

/// Algorithm identifier carried in a token header's `alg` field
/// (see [`TokenHeader::alg`]).
///
/// Every variant is (de)serialized under its own identifier, except
/// [`JWTAlgorithm::NONE`], which is renamed to the lowercase string `"none"`.
#[derive(Debug, Deserialize, Serialize, Clone, Eq, PartialEq)]
pub enum JWTAlgorithm {
    /// (De)serialized as `"none"` rather than `"NONE"`.
    #[serde(rename = "none")]
    NONE,
    HS256,
    HS384,
    HS512,
    RS256,
    RS384,
    RS512,
    ES256,
    ES384,
    ES512,
}

/// The subset of a token's JSON header that this module deserializes.
#[derive(Debug, Deserialize, Serialize, Clone)]
pub struct TokenHeader {
    /// Algorithm identifier from the header's `alg` field.
    pub alg: JWTAlgorithm,
    /// Optional `kid` field; `None` when the header does not carry one.
    // NOTE(review): presumably the id of the key used to sign the token — confirm against callers.
    pub kid: Option<String>,
}

/// Strongly typed view of the claims this module requires in every token.
///
/// None of the fields is optional or defaulted, so deserialization fails when
/// any of them is absent from the claims JSON. The three time fields go through
/// `time::serde::timestamp`, i.e. they are read and written as Unix timestamps.
#[derive(Debug, Deserialize, Serialize, Clone)]
pub struct TokenClaims {
    /// `exp` claim, (de)serialized as a Unix timestamp.
    #[serde(with = "timestamp")]
    pub exp: OffsetDateTime,
    /// `iat` claim, (de)serialized as a Unix timestamp.
    #[serde(with = "timestamp")]
    pub iat: OffsetDateTime,
    /// `aud` claim; must be a single JSON string to deserialize.
    pub aud: String,
    /// `iss` claim.
    pub iss: String,
    /// `sub` claim.
    pub sub: String,
    /// `auth_time` claim, (de)serialized as a Unix timestamp.
    #[serde(with = "timestamp")]
    pub auth_time: OffsetDateTime,
}

/// A token split into its parts, as produced by [`JWToken::from_encoded`].
#[derive(Debug, Clone)]
pub struct JWToken {
    /// Header decoded from the first segment of the token.
    pub header: TokenHeader,
    /// Typed view of the required claims, decoded from the second segment.
    pub critical_claims: TokenClaims,
    /// Every top-level claim from the second segment, keyed by claim name.
    pub all_claims: BTreeMap<String, Value>,
    /// The first two segments exactly as received (still base64-encoded),
    /// joined by `.`.
    // NOTE(review): presumably the input over which `signature` is verified — confirm against callers.
    pub payload: String,
    /// Raw signature bytes, base64-decoded from the third segment.
    pub signature: Vec<u8>,
}

impl JWToken {
    /// Parses a compact token of the form `header.claims.signature`.
    ///
    /// Each segment is decoded as unpadded URL-safe base64. The header and
    /// claims segments must additionally contain JSON that deserializes into
    /// [`TokenHeader`] and [`TokenClaims`]. The signature is only decoded into
    /// bytes; it is **not** verified here.
    ///
    /// # Errors
    ///
    /// * [`JWTError::MissingHeader`] when there is no `.`-separated claims
    ///   segment after the header.
    /// * [`JWTError::MissingSignature`] when there is no third segment.
    /// * [`JWTError::FailedToParse`] when a segment is not valid base64, the
    ///   header or claims JSON does not match the expected shape, or the token
    ///   has more than three segments.
    pub fn from_encoded(encoded: &str) -> Result<Self, Report<JWTError>> {
        let engine = base64::engine::general_purpose::URL_SAFE_NO_PAD;
        let mut parts = encoded.split('.');

        // `str::split` always yields at least one item, so in practice an empty
        // header surfaces as `FailedToParse` from the JSON step below.
        // `ok_or_else` keeps the `Report` from being built on the success path.
        let header_slice = parts
            .next()
            .ok_or_else(|| Report::new(JWTError::MissingHeader))?;
        let header_bytes = engine
            .decode(header_slice)
            .into_report()
            .change_context(JWTError::FailedToParse)?;
        let header: TokenHeader = from_slice(&header_bytes)
            .into_report()
            .change_context(JWTError::FailedToParse)?;

        // NOTE(review): a missing claims segment is reported as `MissingHeader`;
        // kept as-is because no dedicated variant is visible from this file and
        // callers may already match on it.
        let claims_slice = parts
            .next()
            .ok_or_else(|| Report::new(JWTError::MissingHeader))?;
        let claims = engine
            .decode(claims_slice)
            .into_report()
            .change_context(JWTError::FailedToParse)?;

        // The same bytes are deserialized twice: once into the typed, required
        // claims and once into an open map that keeps every top-level claim.
        let critical_claims: TokenClaims = from_slice(&claims)
            .into_report()
            .change_context(JWTError::FailedToParse)?;
        let all_claims: BTreeMap<String, Value> = from_slice(&claims)
            .into_report()
            .change_context(JWTError::FailedToParse)?;

        let signature_slice = parts
            .next()
            .ok_or_else(|| Report::new(JWTError::MissingSignature))?;

        // A compact token has exactly three segments; anything after the
        // signature means the input is malformed and must not be accepted.
        if parts.next().is_some() {
            return Err(Report::new(JWTError::FailedToParse));
        }

        let signature = engine
            .decode(signature_slice)
            .into_report()
            .change_context(JWTError::FailedToParse)?;

        Ok(Self {
            header,
            critical_claims,
            all_claims,
            // Keep the segments exactly as received (still encoded).
            payload: [header_slice, claims_slice].join("."),
            signature,
        })
    }
}

/// Produces the signature segment of a token for [`encode_jwt`].
pub trait JwtSigner {
    /// Signs a token's header and payload.
    ///
    /// [`encode_jwt`] passes the already base64-encoded header and payload
    /// segments and appends the returned string verbatim as the third segment
    /// of the token, so an implementation is expected to return the signature
    /// in its final, encoded form. Takes `&mut self`, so implementations may
    /// update internal state while signing.
    fn sign_jwt(&mut self, header: &str, payload: &str) -> Result<String, Report<JWTError>>;
}

/// Builds a compact `header.payload.signature` token string.
///
/// `header` and `payload` are each serialized to JSON and encoded as unpadded
/// URL-safe base64. Both encoded segments are handed to `signer`, and the
/// string it returns is appended verbatim as the third segment.
///
/// # Errors
///
/// Returns [`JWTError::FailedToEncode`] when `header` or `payload` cannot be
/// serialized to JSON, and forwards any error returned by the signer unchanged.
pub fn encode_jwt<HeaderT: Serialize, PayloadT: Serialize, SignerT: JwtSigner>(
    header: &HeaderT,
    payload: &PayloadT,
    mut signer: SignerT,
) -> Result<String, Report<JWTError>> {
    // Serializes `value` to JSON, then base64url-encodes it without padding.
    fn to_segment<T: Serialize>(value: &T) -> Result<String, Report<JWTError>> {
        let json = to_string(value)
            .into_report()
            .change_context(JWTError::FailedToEncode)?;
        Ok(base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(json))
    }

    let header_b64 = to_segment(header)?;
    let payload_b64 = to_segment(payload)?;
    let signature = signer.sign_jwt(&header_b64, &payload_b64)?;

    Ok([header_b64, payload_b64, signature].join("."))
}