use std::time::Duration;

use anyhow::{Context, Result};
use chrono::{DateTime, Utc};
use oauth2::{
    basic::{
        BasicErrorResponse, BasicRevocationErrorResponse, BasicTokenIntrospectionResponse,
        BasicTokenType,
    },
    AccessToken, AuthType, AuthUrl, AuthorizationCode, Client, ClientId, ClientSecret,
    EndpointNotSet, EndpointSet, ExtraTokenFields, IntrospectionUrl, RedirectUrl, RefreshToken,
    Scope, StandardRevocableToken, StandardTokenResponse, TokenResponse, TokenUrl,
};
use reqwest::{redirect, ClientBuilder};
use serde::{Deserialize, Serialize};
use serde_with::{formats::PreferMany, serde_as, OneOrMany, TimestampSeconds};

use crate::{Model, User};
14
15type NumericDate = DateTime<Utc>;
16
17type ClaimStrings = Vec<String>;
18
19#[derive(Debug, thiserror::Error)]
20pub enum ValidationError {
21 #[error("token is expired")]
22 Expired,
23 #[error("token used before issued")]
24 IssuedAt,
25 #[error("token is not valid yet")]
26 NotValidYet,
27}
28
29#[derive(Debug, Clone, Serialize, Deserialize, Default)]
30#[serde(rename_all = "camelCase", default)]
31pub struct ClaimsStandard {
32 #[serde(flatten)]
33 pub user: User,
34 pub email_verified: bool,
35 pub phone_number: String,
36 pub phone_number_verified: bool,
37 pub gender: String,
38 pub token_type: Option<String>,
39 pub nonce: Option<String>,
40 pub scope: Option<String>,
41 pub address: OIDCAddress,
42 pub tag: String,
43 #[serde(flatten)]
44 pub reg_claims: RegisteredClaims,
45}
46
47#[derive(Serialize, Deserialize, Default, Clone, Debug)]
48#[serde(default)]
49pub struct OIDCAddress {
50 #[serde(rename = "formatted")]
51 pub formatted: String,
52 #[serde(rename = "street_address")]
53 pub street_address: String,
54 #[serde(rename = "locality")]
55 pub locality: String,
56 #[serde(rename = "region")]
57 pub region: String,
58 #[serde(rename = "postal_code")]
59 pub postal_code: String,
60 #[serde(rename = "country")]
61 pub country: String,
62}
63
64#[serde_as]
65#[derive(Serialize, Deserialize, Debug, Clone, Default)]
66#[serde(default)]
67pub struct RegisteredClaims {
68 #[serde(rename = "iss", skip_serializing_if = "Option::is_none")]
69 pub issuer: Option<String>,
70 #[serde(rename = "sub", skip_serializing_if = "Option::is_none")]
71 pub subject: Option<String>,
72 #[serde(rename = "aud", skip_serializing_if = "Vec::is_empty")]
73 pub audience: ClaimStrings,
74 #[serde(rename = "exp", skip_serializing_if = "Option::is_none")]
75 #[serde_as(as = "Option<TimestampSeconds<i64>>")]
76 pub expires_at: Option<NumericDate>,
77 #[serde(rename = "nbf", skip_serializing_if = "Option::is_none")]
78 #[serde_as(as = "Option<TimestampSeconds<i64>>")]
79 pub not_before: Option<NumericDate>,
80 #[serde(rename = "iat",skip_serializing_if = "Option::is_none")]
81 #[serde_as(as = "Option<TimestampSeconds<i64>>")]
82 pub issued_at: Option<NumericDate>,
83 #[serde(rename = "jti", skip_serializing_if = "Option::is_none")]
84 pub id: Option<String>,
85}
86
87impl RegisteredClaims {
88 pub fn valid(&self) -> Result<(), ValidationError> {
89 let now = Utc::now();
90
91 if !self.verify_expires_at(now, false) {
92 return Err(ValidationError::Expired);
93 }
94
95 if !self.verify_issued_at(now, false) {
96 return Err(ValidationError::IssuedAt);
97 }
98
99 if !self.verify_not_before(now, false) {
100 return Err(ValidationError::NotValidYet);
101 }
102
103 Ok(())
104 }
105
106 pub fn verify_expires_at(&self, cmp: NumericDate, require: bool) -> bool {
107 if cmp.timestamp().eq(&0) {
108 return !require;
109 }
110 if let Some(exp) = self.expires_at {
111 return cmp < exp;
112 }
113
114 !require
115 }
116
117 pub fn verify_issued_at(&self, cmp: NumericDate, require: bool) -> bool {
118 if cmp.timestamp().eq(&0) {
119 return !require;
120 }
121 if let Some(iat) = self.issued_at {
122 return cmp >= iat;
123 }
124
125 !require
126 }
127 pub fn verify_not_before(&self, cmp: NumericDate, require: bool) -> bool {
128 if cmp.timestamp().eq(&0) {
129 return !require;
130 }
131 if let Some(nbf) = self.not_before {
132 return cmp >= nbf;
133 }
134
135 !require
136 }
137}
138
139#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Default)]
140#[serde(rename_all = "camelCase", default)]
141pub struct Session {
142 owner: String,
143 name: String,
144 application: String,
145 created_time: String,
146 session_id: Vec<String>,
147}
148
149impl Session {
150 pub fn get_pk_id(&self) -> String {
151 format!("{}/{}/{}", self.owner, self.name, self.application)
152 }
153}
154
155impl Model for Session {
156 fn ident() -> &'static str {
157 "session"
158 }
159 fn plural_ident() -> &'static str {
160 "sessions"
161 }
162 fn support_update_columns() -> bool {
163 true
164 }
165 fn owner(&self) -> &str {
166 &self.owner
167 }
168 fn name(&self) -> &str {
169 &self.name
170 }
171}
172
173impl ExtraTokenFields for CasdoorExtraTokenFields {}
174
175#[derive(Debug, Deserialize, Serialize)]
176pub struct CasdoorExtraTokenFields {
177 pub id_token: String,
179}
180
181pub type CasdoorTokenResponse = StandardTokenResponse<CasdoorExtraTokenFields, BasicTokenType>;
182
183pub type CasdoorClient<
184 HasAuthUrl = EndpointSet,
185 HasDeviceAuthUrl = EndpointNotSet,
186 HasIntrospectionUrl = EndpointNotSet,
187 HasRevocationUrl = EndpointNotSet,
188 HasTokenUrl = EndpointNotSet,
189> = Client<
190 BasicErrorResponse,
191 CasdoorTokenResponse,
192 BasicTokenIntrospectionResponse,
193 StandardRevocableToken,
194 BasicRevocationErrorResponse,
195 HasAuthUrl,
196 HasDeviceAuthUrl,
197 HasIntrospectionUrl,
198 HasRevocationUrl,
199 HasTokenUrl,
200>;
201
202#[derive(Clone, Debug, Deserialize, Serialize)]
203pub struct CasdoorResponse<EF: ExtraTokenFields> {
204 pub access_token: AccessToken,
205 pub token_type: BasicTokenType,
206 #[serde(skip_serializing_if = "Option::is_none")]
207 pub expires_in: Option<u64>,
208 #[serde(skip_serializing_if = "Option::is_none")]
209 pub refresh_token: Option<RefreshToken>,
210 #[serde(rename = "scope")]
211 #[serde(deserialize_with = "oauth2::helpers::deserialize_space_delimited_vec")]
212 #[serde(serialize_with = "oauth2::helpers::serialize_space_delimited_vec")]
213 #[serde(skip_serializing_if = "Option::is_none")]
214 #[serde(default)]
215 pub scopes: Option<Vec<Scope>>,
216
217 #[serde(bound = "EF: ExtraTokenFields")]
218 #[serde(flatten)]
219 pub extra_fields: EF,
220}
221
222impl<EF> TokenResponse for CasdoorResponse<EF>
223where
224 EF: ExtraTokenFields,
225{
226 type TokenType = BasicTokenType;
227 fn access_token(&self) -> &AccessToken {
229 &self.access_token
230 }
231 fn token_type(&self) -> &BasicTokenType {
236 &self.token_type
237 }
238 fn expires_in(&self) -> Option<Duration> {
243 self.expires_in.map(Duration::from_secs)
244 }
245 fn refresh_token(&self) -> Option<&RefreshToken> {
249 self.refresh_token.as_ref()
250 }
251 fn scopes(&self) -> Option<&Vec<Scope>> {
257 self.scopes.as_ref()
258 }
259}
260
261pub struct OAuth2Client {
262 pub client: CasdoorClient,
263 pub http_client: reqwest::Client,
264}
265
266impl OAuth2Client {
267 pub(crate) async fn new(client_id: ClientId, client_secret: ClientSecret, auth_url: AuthUrl) -> Result<Self> {
268 let http_client = ClientBuilder::new()
269 .redirect(redirect::Policy::default())
270 .build()
271 .expect("Client must build");
272
273 let client = CasdoorClient::new(client_id)
274 .set_client_secret(client_secret)
275 .set_auth_uri(auth_url);
276
277 Ok(Self { client, http_client })
278 }
279
280 pub async fn refresh_token(self, refresh_token: RefreshToken, token_url: TokenUrl)
281 -> Result<CasdoorTokenResponse> {
282 let token_res: CasdoorTokenResponse = self
283 .client
284 .set_auth_type(AuthType::RequestBody)
285 .set_token_uri(token_url)
286 .exchange_refresh_token(&refresh_token)
287 .add_scope(Scope::new("read".to_string()))
288 .request_async(&self.http_client)
289 .await?;
290
291 Ok(token_res)
292 }
293
294 pub async fn get_oauth_token(self, code: AuthorizationCode, redirect_url: RedirectUrl, token_url: TokenUrl)
295 -> Result<CasdoorTokenResponse> {
296 let token_res = self
297 .client
298 .set_auth_type(AuthType::RequestBody)
299 .set_redirect_uri(redirect_url)
300 .set_token_uri(token_url)
301 .exchange_code(code)
302 .request_async(&self.http_client)
303 .await?;
304
305 Ok(token_res)
306 }
307
308 pub async fn get_introspect_access_token(self, intro_url: IntrospectionUrl, token: &AccessToken)
309 -> Result<BasicTokenIntrospectionResponse> {
310 let res = self
311 .client
312 .set_auth_type(AuthType::BasicAuth)
313 .set_introspection_url(intro_url)
314 .introspect(token)
315 .set_token_type_hint("access_token")
316 .request_async(&self.http_client)
317 .await?;
318
319 Ok(res)
320 }
321}