qm_keycloak/token/
jwt.rs

1use std::{collections::HashSet, sync::Arc};
2
3use jsonwebtoken::{decode, Algorithm, DecodingKey, Validation};
4use serde::{de::DeserializeOwned, Deserialize, Serialize};
5
6#[derive(Debug, Clone, Serialize, Deserialize)]
7pub struct ResourceAccess {
8    pub account: RealmAccess,
9}
10
11#[derive(Debug, Clone, Serialize, Deserialize)]
12pub struct RealmAccess {
13    pub roles: Vec<Arc<str>>,
14}
15
16#[derive(Serialize, Clone, Deserialize, Default)]
17pub struct PartialClaims {
18    pub iss: String,
19    pub azp: String,
20}
21
22#[derive(Debug, Clone, Serialize, Deserialize)]
23pub struct Claims {
24    pub exp: i64,
25    pub iat: i64,
26    pub auth_time: Option<i64>,
27    pub jti: String,
28    pub iss: String,
29    pub aud: serde_json::Value,
30    pub sub: Arc<str>,
31    pub typ: String,
32    pub azp: String,
33    pub acr: String,
34    #[serde(rename = "allowed-origins")]
35    pub allowed_origins: Option<Vec<Arc<str>>>,
36    pub realm_access: RealmAccess,
37    pub resource_access: ResourceAccess,
38    #[serde(default)]
39    pub scope: String,
40    #[serde(default)]
41    pub sid: String,
42    #[serde(default)]
43    pub email_verified: bool,
44    #[serde(default)]
45    pub name: String,
46    #[serde(default)]
47    pub preferred_username: String,
48    #[serde(default)]
49    pub given_name: String,
50    #[serde(default)]
51    pub family_name: String,
52    #[serde(default)]
53    pub email: String,
54    #[serde(skip)]
55    pub is_api_test: bool,
56}
57
58impl Default for Claims {
59    fn default() -> Self {
60        Self {
61            exp: 0,
62            iat: 0,
63            auth_time: None,
64            jti: "".to_string(),
65            iss: "".to_string(),
66            is_api_test: true,
67            sub: Arc::from("user-id"),
68            typ: "".to_string(),
69            azp: "".to_string(),
70            acr: "".to_string(),
71            allowed_origins: None,
72            realm_access: RealmAccess { roles: vec![] },
73            resource_access: ResourceAccess {
74                account: RealmAccess { roles: vec![] },
75            },
76            scope: "".to_string(),
77            sid: "".to_string(),
78            email_verified: false,
79            name: "".to_string(),
80            preferred_username: "".to_string(),
81            given_name: "".to_string(),
82            family_name: "".to_string(),
83            aud: Default::default(),
84            email: "".to_string(),
85        }
86    }
87}
88
89#[derive(Debug, Clone, Serialize, Deserialize)]
90pub struct LogoutClaims {
91    pub iat: i64,
92    pub jti: String,
93    pub iss: String,
94    pub aud: serde_json::Value,
95    pub sub: String,
96    pub typ: String,
97    pub sid: String,
98}
99
100#[derive(Clone)]
101pub struct Jwt {
102    pub kid: String,
103    validation: Validation,
104    logout_validation: Validation,
105    decoding_key: DecodingKey,
106}
107
108impl Jwt {
109    pub fn new(
110        alg: Algorithm,
111        kid: String,
112        public_key: &str,
113        client_id: &str,
114    ) -> anyhow::Result<Self> {
115        let mut validation = Validation::new(alg);
116        validation.set_audience(&[client_id, "account"]);
117        // needed workaround to validate logout tokens (they contain no exp field)
118        let mut logout_validation = Validation::new(alg);
119        logout_validation.validate_exp = false;
120        logout_validation.required_spec_claims = HashSet::new();
121        logout_validation
122            .required_spec_claims
123            .insert("sub".to_string());
124        logout_validation
125            .required_spec_claims
126            .insert("iss".to_string());
127        logout_validation
128            .required_spec_claims
129            .insert("aud".to_string());
130        Ok(Self {
131            kid,
132            validation,
133            logout_validation,
134            decoding_key: DecodingKey::from_rsa_pem(
135                format!("-----BEGIN PUBLIC KEY-----\n{public_key}\n-----END PUBLIC KEY-----")
136                    .as_bytes(),
137            )?,
138        })
139    }
140
141    pub fn decode(&self, token: &str) -> anyhow::Result<Claims> {
142        self.decode_custom(token)
143    }
144
145    pub fn decode_custom<C: DeserializeOwned + Clone>(&self, token: &str) -> anyhow::Result<C> {
146        let result = decode(token, &self.decoding_key, &self.validation)?;
147        Ok(result.claims)
148    }
149
150    pub fn decode_logout_token(&self, token: &str) -> anyhow::Result<LogoutClaims> {
151        self.decode_logout_token_custom(token)
152    }
153
154    pub fn decode_logout_token_custom<C: DeserializeOwned + Clone>(
155        &self,
156        token: &str,
157    ) -> anyhow::Result<C> {
158        let result = decode(token, &self.decoding_key, &self.logout_validation)?;
159        Ok(result.claims)
160    }
161}