1use std::{collections::HashSet, sync::Arc};
2
3use jsonwebtoken::{decode, Algorithm, DecodingKey, Validation};
4use serde::{de::DeserializeOwned, Deserialize, Serialize};
5
6#[derive(Debug, Clone, Serialize, Deserialize)]
7pub struct ResourceAccess {
8 pub account: RealmAccess,
9}
10
11#[derive(Debug, Clone, Serialize, Deserialize)]
12pub struct RealmAccess {
13 pub roles: Vec<Arc<str>>,
14}
15
16#[derive(Serialize, Clone, Deserialize, Default)]
17pub struct PartialClaims {
18 pub iss: String,
19 pub azp: String,
20}
21
22#[derive(Debug, Clone, Serialize, Deserialize)]
23pub struct Claims {
24 pub exp: i64,
25 pub iat: i64,
26 pub auth_time: Option<i64>,
27 pub jti: String,
28 pub iss: String,
29 pub aud: serde_json::Value,
30 pub sub: Arc<str>,
31 pub typ: String,
32 pub azp: String,
33 pub acr: String,
34 #[serde(rename = "allowed-origins")]
35 pub allowed_origins: Option<Vec<Arc<str>>>,
36 pub realm_access: RealmAccess,
37 pub resource_access: ResourceAccess,
38 #[serde(default)]
39 pub scope: String,
40 #[serde(default)]
41 pub sid: String,
42 #[serde(default)]
43 pub email_verified: bool,
44 #[serde(default)]
45 pub name: String,
46 #[serde(default)]
47 pub preferred_username: String,
48 #[serde(default)]
49 pub given_name: String,
50 #[serde(default)]
51 pub family_name: String,
52 #[serde(default)]
53 pub email: String,
54 #[serde(skip)]
55 pub is_api_test: bool,
56}
57
58impl Default for Claims {
59 fn default() -> Self {
60 Self {
61 exp: 0,
62 iat: 0,
63 auth_time: None,
64 jti: "".to_string(),
65 iss: "".to_string(),
66 is_api_test: true,
67 sub: Arc::from("user-id"),
68 typ: "".to_string(),
69 azp: "".to_string(),
70 acr: "".to_string(),
71 allowed_origins: None,
72 realm_access: RealmAccess { roles: vec![] },
73 resource_access: ResourceAccess {
74 account: RealmAccess { roles: vec![] },
75 },
76 scope: "".to_string(),
77 sid: "".to_string(),
78 email_verified: false,
79 name: "".to_string(),
80 preferred_username: "".to_string(),
81 given_name: "".to_string(),
82 family_name: "".to_string(),
83 aud: Default::default(),
84 email: "".to_string(),
85 }
86 }
87}
88
89#[derive(Debug, Clone, Serialize, Deserialize)]
90pub struct LogoutClaims {
91 pub iat: i64,
92 pub jti: String,
93 pub iss: String,
94 pub aud: serde_json::Value,
95 pub sub: String,
96 pub typ: String,
97 pub sid: String,
98}
99
100#[derive(Clone)]
101pub struct Jwt {
102 pub kid: String,
103 validation: Validation,
104 logout_validation: Validation,
105 decoding_key: DecodingKey,
106}
107
108impl Jwt {
109 pub fn new(
110 alg: Algorithm,
111 kid: String,
112 public_key: &str,
113 client_id: &str,
114 ) -> anyhow::Result<Self> {
115 let mut validation = Validation::new(alg);
116 validation.set_audience(&[client_id, "account"]);
117 let mut logout_validation = Validation::new(alg);
119 logout_validation.validate_exp = false;
120 logout_validation.required_spec_claims = HashSet::new();
121 logout_validation
122 .required_spec_claims
123 .insert("sub".to_string());
124 logout_validation
125 .required_spec_claims
126 .insert("iss".to_string());
127 logout_validation
128 .required_spec_claims
129 .insert("aud".to_string());
130 Ok(Self {
131 kid,
132 validation,
133 logout_validation,
134 decoding_key: DecodingKey::from_rsa_pem(
135 format!("-----BEGIN PUBLIC KEY-----\n{public_key}\n-----END PUBLIC KEY-----")
136 .as_bytes(),
137 )?,
138 })
139 }
140
141 pub fn decode(&self, token: &str) -> anyhow::Result<Claims> {
142 self.decode_custom(token)
143 }
144
145 pub fn decode_custom<C: DeserializeOwned + Clone>(&self, token: &str) -> anyhow::Result<C> {
146 let result = decode(token, &self.decoding_key, &self.validation)?;
147 Ok(result.claims)
148 }
149
150 pub fn decode_logout_token(&self, token: &str) -> anyhow::Result<LogoutClaims> {
151 self.decode_logout_token_custom(token)
152 }
153
154 pub fn decode_logout_token_custom<C: DeserializeOwned + Clone>(
155 &self,
156 token: &str,
157 ) -> anyhow::Result<C> {
158 let result = decode(token, &self.decoding_key, &self.logout_validation)?;
159 Ok(result.claims)
160 }
161}