// casdoor_sdk_rust/authn/models.rs

1use std::time::Duration;
2use chrono::{DateTime, Utc};
3use crate::{Model, User};
4use anyhow::Result;
5use oauth2::{
6    basic::{BasicErrorResponse, BasicRevocationErrorResponse, BasicTokenIntrospectionResponse, BasicTokenType},
7    AccessToken, AuthType, AuthUrl, AuthorizationCode, Client, ClientId, ClientSecret, EndpointNotSet, EndpointSet, ExtraTokenFields,
8    IntrospectionUrl, RedirectUrl, RefreshToken, Scope, StandardRevocableToken, StandardTokenResponse, TokenUrl,
9    TokenResponse
10};
11use serde_with::{serde_as, TimestampSeconds};
12use reqwest::{redirect, ClientBuilder};
13use serde::{Deserialize, Serialize};
14
/// UTC date-time type used for the registered JWT time claims (`exp`, `nbf`, `iat`).
type NumericDate = DateTime<Utc>;

/// String-list type used for the JWT `aud` (audience) claim.
type ClaimStrings = Vec<String>;
18
19#[derive(Debug, thiserror::Error)]
20pub enum ValidationError {
21    #[error("token is expired")]
22    Expired,
23    #[error("token used before issued")]
24    IssuedAt,
25    #[error("token is not valid yet")]
26    NotValidYet,
27}
28
29#[derive(Debug, Clone, Serialize, Deserialize, Default)]
30#[serde(rename_all = "camelCase", default)]
31pub struct ClaimsStandard {
32    #[serde(flatten)]
33    pub user: User,
34    pub email_verified: bool,
35    pub phone_number: String,
36    pub phone_number_verified: bool,
37    pub gender: String,
38    pub token_type: Option<String>,
39    pub nonce: Option<String>,
40    pub scope: Option<String>,
41    pub address: OIDCAddress,
42    pub tag: String,
43    #[serde(flatten)]
44    pub reg_claims: RegisteredClaims,
45}
46
47#[derive(Serialize, Deserialize, Default, Clone, Debug)]
48#[serde(default)]
49pub struct OIDCAddress {
50    #[serde(rename = "formatted")]
51    pub formatted: String,
52    #[serde(rename = "street_address")]
53    pub street_address: String,
54    #[serde(rename = "locality")]
55    pub locality: String,
56    #[serde(rename = "region")]
57    pub region: String,
58    #[serde(rename = "postal_code")]
59    pub postal_code: String,
60    #[serde(rename = "country")]
61    pub country: String,
62}
63
64#[serde_as]
65#[derive(Serialize, Deserialize, Debug, Clone, Default)]
66#[serde(default)]
67pub struct RegisteredClaims {
68    #[serde(rename = "iss", skip_serializing_if = "Option::is_none")]
69    pub issuer: Option<String>,
70    #[serde(rename = "sub", skip_serializing_if = "Option::is_none")]
71    pub subject: Option<String>,
72    #[serde(rename = "aud", skip_serializing_if = "Vec::is_empty")]
73    pub audience: ClaimStrings,
74    #[serde(rename = "exp", skip_serializing_if = "Option::is_none")]
75    #[serde_as(as = "Option<TimestampSeconds<i64>>")]
76    pub expires_at: Option<NumericDate>,
77    #[serde(rename = "nbf", skip_serializing_if = "Option::is_none")]
78    #[serde_as(as = "Option<TimestampSeconds<i64>>")]
79    pub not_before: Option<NumericDate>,
80    #[serde(rename = "iat",skip_serializing_if = "Option::is_none")]
81    #[serde_as(as = "Option<TimestampSeconds<i64>>")]
82    pub issued_at: Option<NumericDate>,
83    #[serde(rename = "jti", skip_serializing_if = "Option::is_none")]
84    pub id: Option<String>,
85}
86
87impl RegisteredClaims {
88    pub fn valid(&self) -> Result<(), ValidationError> {
89        let now = Utc::now();
90
91        if !self.verify_expires_at(now, false) {
92            return Err(ValidationError::Expired);
93        }
94
95        if !self.verify_issued_at(now, false) {
96            return Err(ValidationError::IssuedAt);
97        }
98
99        if !self.verify_not_before(now, false) {
100            return Err(ValidationError::NotValidYet);
101        }
102
103        Ok(())
104    }
105
106    pub fn verify_expires_at(&self, cmp: NumericDate, require: bool) -> bool {
107        if cmp.timestamp().eq(&0) {
108            return !require;
109        }
110        if let Some(exp) = self.expires_at {
111            return cmp < exp;
112        }
113
114        !require
115    }
116
117    pub fn verify_issued_at(&self, cmp: NumericDate, require: bool) -> bool {
118        if cmp.timestamp().eq(&0) {
119            return !require;
120        }
121        if let Some(iat) = self.issued_at {
122            return cmp >= iat;
123        }
124
125        !require
126    }
127    pub fn verify_not_before(&self, cmp: NumericDate, require: bool) -> bool {
128        if cmp.timestamp().eq(&0) {
129            return !require;
130        }
131        if let Some(nbf) = self.not_before {
132            return cmp >= nbf;
133        }
134
135        !require
136    }
137}
138
139#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Default)]
140#[serde(rename_all = "camelCase", default)]
141pub struct Session {
142    owner: String,
143    name: String,
144    application: String,
145    created_time: String,
146    session_id: Vec<String>,
147}
148
149impl Session {
150    pub fn get_pk_id(&self) -> String {
151        format!("{}/{}/{}", self.owner, self.name, self.application)
152    }
153}
154
155impl Model for Session {
156    fn ident() -> &'static str {
157        "session"
158    }
159    fn plural_ident() -> &'static str {
160        "sessions"
161    }
162    fn support_update_columns() -> bool {
163        true
164    }
165    fn owner(&self) -> &str {
166        &self.owner
167    }
168    fn name(&self) -> &str {
169        &self.name
170    }
171}
172
173impl ExtraTokenFields for CasdoorExtraTokenFields {}
174
175#[derive(Debug, Deserialize, Serialize)]
176pub struct CasdoorExtraTokenFields {
177    /// This field only use in OpenID Connect
178    pub id_token: String,
179}
180
/// Token endpoint response used by [`CasdoorClient`]: the standard OAuth 2.0 members
/// plus [`CasdoorExtraTokenFields`], with [`BasicTokenType`] as the token type.
pub type CasdoorTokenResponse = StandardTokenResponse<CasdoorExtraTokenFields, BasicTokenType>;
182
/// `oauth2` client specialised for Casdoor.
///
/// It uses the basic error, introspection and revocation types and
/// [`CasdoorTokenResponse`] as the token response. The five trailing parameters
/// record, at the type level, which endpoints are configured. By default only the
/// authorization URL is marked as set; the device-authorization, introspection,
/// revocation and token URLs are not.
pub type CasdoorClient<
    HasAuthUrl = EndpointSet,
    HasDeviceAuthUrl = EndpointNotSet,
    HasIntrospectionUrl = EndpointNotSet,
    HasRevocationUrl = EndpointNotSet,
    HasTokenUrl = EndpointNotSet,
> = Client<
    BasicErrorResponse,
    CasdoorTokenResponse,
    BasicTokenIntrospectionResponse,
    StandardRevocableToken,
    BasicRevocationErrorResponse,
    HasAuthUrl,
    HasDeviceAuthUrl,
    HasIntrospectionUrl,
    HasRevocationUrl,
    HasTokenUrl,
>;
201
202#[derive(Clone, Debug, Deserialize, Serialize)]
203pub struct CasdoorResponse<EF: ExtraTokenFields> {
204    pub access_token: AccessToken,
205    pub token_type: BasicTokenType,
206    #[serde(skip_serializing_if = "Option::is_none")]
207    pub expires_in: Option<u64>,
208    #[serde(skip_serializing_if = "Option::is_none")]
209    pub refresh_token: Option<RefreshToken>,
210    #[serde(rename = "scope")]
211    #[serde(deserialize_with = "oauth2::helpers::deserialize_space_delimited_vec")]
212    #[serde(serialize_with = "oauth2::helpers::serialize_space_delimited_vec")]
213    #[serde(skip_serializing_if = "Option::is_none")]
214    #[serde(default)]
215    pub scopes: Option<Vec<Scope>>,
216
217    #[serde(bound = "EF: ExtraTokenFields")]
218    #[serde(flatten)]
219    pub extra_fields: EF,
220}
221
222impl<EF> TokenResponse for CasdoorResponse<EF>
223where
224    EF: ExtraTokenFields,
225{
226    type TokenType = BasicTokenType;
227    /// REQUIRED. The access token issued by the authorization server.
228    fn access_token(&self) -> &AccessToken {
229        &self.access_token
230    }
231    /// REQUIRED. The type of the token issued as described in
232    /// [Section 7.1](https://tools.ietf.org/html/rfc6749#section-7.1).
233    /// Value is case insensitive and deserialized to the generic `TokenType` parameter.
234    /// But in this particular case as the service is non compliant, it has a default value
235    fn token_type(&self) -> &BasicTokenType {
236        &self.token_type
237    }
238    /// RECOMMENDED. The lifetime in seconds of the access token. For example, the value 3600
239    /// denotes that the access token will expire in one hour from the time the response was
240    /// generated. If omitted, the authorization server SHOULD provide the expiration time via
241    /// other means or document the default value.
242    fn expires_in(&self) -> Option<Duration> {
243        self.expires_in.map(Duration::from_secs)
244    }
245    /// OPTIONAL. The refresh token, which can be used to obtain new access tokens using the same
246    /// authorization grant as described in
247    /// [Section 6](https://tools.ietf.org/html/rfc6749#section-6).
248    fn refresh_token(&self) -> Option<&RefreshToken> {
249        self.refresh_token.as_ref()
250    }
251    /// OPTIONAL, if identical to the scope requested by the client; otherwise, REQUIRED. The
252    /// scope of the access token as described by
253    /// [Section 3.3](https://tools.ietf.org/html/rfc6749#section-3.3). If included in the response,
254    /// this space-delimited field is parsed into a `Vec` of individual scopes. If omitted from
255    /// the response, this field is `None`.
256    fn scopes(&self) -> Option<&Vec<Scope>> {
257        self.scopes.as_ref()
258    }
259}
260
/// A [`CasdoorClient`] with its authorization URL set, paired with the HTTP client
/// that sends its requests.
pub struct OAuth2Client {
    /// OAuth 2.0 client. The token and introspection URLs are supplied per call by
    /// the methods on this type.
    pub client: CasdoorClient,
    /// HTTP transport passed to every `request_async` call.
    pub http_client: reqwest::Client,
}
265
266impl OAuth2Client {
267    pub(crate) async fn new(client_id: ClientId, client_secret: ClientSecret, auth_url: AuthUrl) -> Result<Self> {
268        let http_client = ClientBuilder::new()
269            .redirect(redirect::Policy::default())
270            .build()
271            .expect("Client must build");
272
273        let client = CasdoorClient::new(client_id)
274            .set_client_secret(client_secret)
275            .set_auth_uri(auth_url);
276
277        Ok(Self { client, http_client })
278    }
279
280    pub async fn refresh_token(self, refresh_token: RefreshToken, token_url: TokenUrl)
281        -> Result<CasdoorTokenResponse> {
282        let token_res: CasdoorTokenResponse = self
283            .client
284            .set_auth_type(AuthType::RequestBody)
285            .set_token_uri(token_url)
286            .exchange_refresh_token(&refresh_token)
287            .add_scope(Scope::new("read".to_string()))
288            .request_async(&self.http_client)
289            .await?;
290
291        Ok(token_res)
292    }
293
294    pub async fn get_oauth_token(self, code: AuthorizationCode, redirect_url: RedirectUrl, token_url: TokenUrl)
295        -> Result<CasdoorTokenResponse> {
296        let token_res = self
297            .client
298            .set_auth_type(AuthType::RequestBody)
299            .set_redirect_uri(redirect_url)
300            .set_token_uri(token_url)
301            .exchange_code(code)
302            .request_async(&self.http_client)
303            .await?;
304
305        Ok(token_res)
306    }
307
308    pub async fn get_introspect_access_token(self, intro_url: IntrospectionUrl, token: &AccessToken)
309        -> Result<BasicTokenIntrospectionResponse> {
310        let res = self
311            .client
312            .set_auth_type(AuthType::BasicAuth)
313            .set_introspection_url(intro_url)
314            .introspect(token)
315            .set_token_type_hint("access_token")
316            .request_async(&self.http_client)
317            .await?;
318
319        Ok(res)
320    }
321}