1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
use anyhow::bail;
use base64::Engine;
use serde::{Deserialize, Serialize};
use serde::de::DeserializeOwned;
use base64::prelude::{BASE64_URL_SAFE_NO_PAD};

/// A JWT in compact form, split into its three dot-separated segments and decoded.
///
/// `'a` is the lifetime of the token string the raw segments borrow from.
/// `T` is the caller-chosen type the payload segment is deserialized into.
#[derive(Debug)]
pub struct JwtParser<'a, T> {
    /// The raw (still base64url-encoded) segments, in token order, borrowed from the input.
    pub parts: Vec<&'a str>,
    /// The first segment, base64url-decoded and deserialized from JSON.
    pub header: JwtHeader,
    /// The second segment, base64url-decoded and deserialized from JSON into `T`.
    pub payload: T,
    /// The third segment, base64url-decoded into raw bytes.
    pub sig: Vec<u8>,
}

impl<'a, T> JwtParser<'a, T> where T: DeserializeOwned {
    /// Splits `token` on `'.'` and decodes each of the three segments.
    ///
    /// The segments are decoded in order (header, payload, signature); the first
    /// failure is returned and later segments are not examined. No signature
    /// verification is performed here.
    ///
    /// # Errors
    ///
    /// Returns an error when:
    /// - the token does not consist of exactly three `'.'`-separated segments,
    /// - any segment is not valid unpadded URL-safe base64,
    /// - the header JSON does not deserialize into [`JwtHeader`], or
    /// - the payload JSON does not deserialize into `T`.
    pub fn parse(token: &'a str) -> anyhow::Result<JwtParser<'a, T>> {
        let parts: Vec<&'a str> = token.split('.').collect();

        if parts.len() != 3 {
            bail!("invalid jwt token, token must have 3 segments, but get {}", parts.len());
        }

        // Decode everything before building the struct so that `parts` can be
        // moved into it afterwards instead of being cloned.
        let header = Self::parse_header(parts[0])?;
        let payload = Self::parse_payload(parts[1])?;
        let sig = Self::parse_signature(parts[2])?;

        Ok(Self {
            parts,
            header,
            payload,
            sig,
        })
    }

    /// Base64url-decodes the header segment and deserializes it from JSON.
    fn parse_header(header: &str) -> anyhow::Result<JwtHeader> {
        let raw = BASE64_URL_SAFE_NO_PAD.decode(header)?;

        Ok(serde_json::from_slice(&raw)?)
    }

    /// Base64url-decodes the payload segment and deserializes it from JSON into `T`.
    fn parse_payload(payload: &str) -> anyhow::Result<T> {
        let raw = BASE64_URL_SAFE_NO_PAD.decode(payload)?;

        Ok(serde_json::from_slice(&raw)?)
    }

    /// Base64url-decodes the signature segment into raw bytes.
    fn parse_signature(sig: &str) -> anyhow::Result<Vec<u8>> {
        Ok(BASE64_URL_SAFE_NO_PAD.decode(sig)?)
    }

    /// Returns the first two raw segments joined by `'.'` (`<header>.<payload>`).
    ///
    /// In a compact JWT this is the text the signature is computed over.
    ///
    /// # Panics
    ///
    /// Panics if `parts` holds fewer than two segments. Values produced by
    /// [`JwtParser::parse`] always hold three, but `parts` is a public field.
    pub fn msg(&self) -> String {
        self.parts[..2].join(".")
    }
}

/// The header object of a JWT, (de)serialized as JSON with the field names as keys.
///
/// Every field is a plain `String` with no serde default, so deserialization
/// fails when any of the three keys is absent from the header JSON.
// NOTE(review): RFC 7515 treats `typ` and `kid` as optional header parameters;
// tokens that omit them are rejected by this type — confirm that is intended.
#[derive(Serialize, Deserialize, Clone, Debug)]
pub struct JwtHeader {
    /// The `alg` header parameter (signature algorithm identifier in RFC 7515).
    pub alg: String,
    /// The `typ` header parameter (media type of the token in RFC 7515).
    pub typ: String,
    /// The `kid` header parameter (key identifier hint in RFC 7515).
    pub kid: String,
}