// oauth21_server/provider.rs

use std::collections::HashMap;

use async_trait::async_trait;
use url::{Host::Domain, Url};

use crate::{
    client::{Client, GrantType, ResponseType, TokenEndpointAuthMethod},
    crypto::{decode_base64, random_secure_string, sha256},
    error::{ErrorCode, OAuthError},
};

/// PKCE code challenge method (RFC 7636 §4.2) recorded with an authorization request.
///
/// Equality is derived so callers can compare methods directly instead of matching.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum CodeChallengeMethod {
    /// `plain`: the code challenge is the code verifier itself.
    Plain,
    /// `S256`: the code challenge is `BASE64URL(SHA256(code_verifier))` without padding.
    S256,
}

18pub struct AuthorizationRequest {
19    pub response_type: ResponseType,
20    pub client_id: String,
21    pub code_challenge: String,
22    pub code_challenge_method: Option<CodeChallengeMethod>,
23    pub redirect_uri: Option<String>,
24    pub scope: Option<String>,
25    pub state: Option<String>,
26}
27
28pub struct TokenRequest {
29    pub grant_type: GrantType,
30    pub code: String,
31    pub redirect_uri: Option<String>,
32    pub client_id: Option<String>,
33    pub code_verifier: String,
34}
35
36pub struct ClientCredentialsTokenRequest {
37    pub grant_type: GrantType,
38    pub scope: Option<String>,
39}
40
/// Successful authorization response: the issued code plus the client's echoed `state`.
#[derive(Debug)]
pub struct AuthorizationResponse {
    /// Authorization code to be exchanged at the token endpoint.
    pub code: String,
    /// The `state` from the authorization request, returned unchanged.
    pub state: Option<String>,
}

/// Successful token endpoint response.
///
/// The fields stay private so the response cannot be constructed outside this module. The
/// accessors below let callers outside the module read it in order to reply to the client.
#[derive(Debug)]
pub struct TokenResponse {
    access_token: String,
    token_type: String,
    expires_in: u64,
    scope: String,
}

impl TokenResponse {
    /// The issued access token.
    pub fn access_token(&self) -> &str {
        &self.access_token
    }

    /// The token type (the flows in this module use `"Bearer"`).
    pub fn token_type(&self) -> &str {
        &self.token_type
    }

    /// The `expires_in` value (seconds, per RFC 6749 §5.1).
    pub fn expires_in(&self) -> u64 {
        self.expires_in
    }

    /// Space-separated scopes granted with the token.
    pub fn scope(&self) -> &str {
        &self.scope
    }
}

55pub struct VerifiedAuthorizationRequest {
56    client_id: String,
57    code_challenge: String,
58    code_challenge_method: CodeChallengeMethod,
59    redirect_uri: String,
60    scope: String,
61    state: Option<String>,
62}
63
64pub struct VerifiedTokenRequest {
65    authentication_information: AuthorizationInformation,
66}
67
68pub struct VerifiedClientCredentialsTokenRequest {
69    pub scope: String,
70}
71
72#[derive(Debug, Clone)]
73pub struct AuthorizationInformation {
74    client_id: String,
75    redirect_uri: String,
76    code_challenge: String,
77    code_challenge_method: CodeChallengeMethod,
78    scope: String,
79    state: Option<String>,
80    is_valid: bool,
81}
82
/// Client metadata returned together with a verified authorization request, presumably so a
/// sign-in/consent step can show the user who is asking.
///
/// The fields stay private; the accessors make them readable outside this module, which
/// previously could only `Debug`-print the value.
#[derive(Debug)]
pub struct SigninInformation {
    client_name: String,
    client_uri: String,
    logo_uri: String,
    scopes: Vec<String>,
    contacts: Vec<String>,
    tos_uri: String,
    policy_uri: String,
}

impl SigninInformation {
    /// Display name of the client.
    pub fn client_name(&self) -> &str {
        &self.client_name
    }

    /// The client's homepage URI.
    pub fn client_uri(&self) -> &str {
        &self.client_uri
    }

    /// URI of the client's logo.
    pub fn logo_uri(&self) -> &str {
        &self.logo_uri
    }

    /// Scopes copied from the client record.
    pub fn scopes(&self) -> &[String] {
        &self.scopes
    }

    /// Contacts listed for the client.
    pub fn contacts(&self) -> &[String] {
        &self.contacts
    }

    /// URI of the client's terms of service.
    pub fn tos_uri(&self) -> &str {
        &self.tos_uri
    }

    /// URI of the client's privacy policy.
    pub fn policy_uri(&self) -> &str {
        &self.policy_uri
    }
}

94#[async_trait(?Send)]
95pub trait Provider
96where
97    Self: Sized,
98{
99    async fn store_client(&self, client: Client) -> Result<(), OAuthError>;
100    async fn get_client(&self, client_id: &str) -> Option<Client>;
101    async fn save_authorization_information(
102        &self,
103        id: String,
104        information: AuthorizationInformation,
105    ) -> Result<(), OAuthError>;
106    async fn get_authorization_information(&self, id: &str) -> Option<AuthorizationInformation>;
107    async fn remove_authorization_information(&self, id: &str) -> Result<(), OAuthError>;
108
109    async fn verify_authorization_request<T: AuthorizationFlow>(
110        &self,
111        request: AuthorizationRequest,
112        flow: T,
113    ) -> Result<(VerifiedAuthorizationRequest, SigninInformation), OAuthError> {
114        let client = self
115            .get_client(&request.client_id)
116            .await
117            .ok_or(OAuthError::new(
118                ErrorCode::InvalidRequest,
119                request.state.clone(),
120            ))?;
121        let request = flow.verify_authorization(&client, request).await?;
122        Ok((
123            request,
124            SigninInformation {
125                client_name: client.name,
126                client_uri: client.uri,
127                logo_uri: client.logo_uri,
128                scopes: client.scopes,
129                contacts: client.contacts,
130                tos_uri: client.tos_uri,
131                policy_uri: client.policy_uri,
132            },
133        ))
134    }
135
136    async fn authorize<T: AuthorizationFlow>(
137        &self,
138        request: VerifiedAuthorizationRequest,
139        flow: T,
140    ) -> Result<T::Response, OAuthError> {
141        flow.perform_authorization(self, request).await
142    }
143
144    async fn get_token<T: TokenFlow>(
145        &self,
146        authenticated_client: Option<Client>,
147        request: T::Request,
148        flow: T,
149    ) -> Result<T::Response, OAuthError> {
150        let request = flow
151            .verify_token_request(authenticated_client, request, self)
152            .await?;
153        flow.perform_token_exchange(request).await
154    }
155}
156
157#[async_trait(?Send)]
158pub trait AuthorizationFlow {
159    type Response;
160
161    async fn verify_authorization(
162        &self,
163        client: &Client,
164        request: AuthorizationRequest,
165    ) -> Result<VerifiedAuthorizationRequest, OAuthError>;
166
167    async fn perform_authorization(
168        &self,
169        provider: &impl Provider,
170        request: VerifiedAuthorizationRequest,
171    ) -> Result<Self::Response, OAuthError>;
172}
173
174#[async_trait(?Send)]
175pub trait TokenFlow {
176    type Request;
177    type VerifiedRequest;
178    type Response;
179
180    async fn verify_token_request(
181        &self,
182        authenticated_client: Option<Client>,
183        request: Self::Request,
184        provider: &impl Provider,
185    ) -> Result<Self::VerifiedRequest, OAuthError>;
186
187    async fn perform_token_exchange(
188        &self,
189        request: Self::VerifiedRequest,
190    ) -> Result<Self::Response, OAuthError>;
191}
192
/// Authorization code grant with PKCE; implements both [`AuthorizationFlow`] and [`TokenFlow`].
///
/// `Provider` methods take flows by value, so `Copy` lets one value be reused across calls.
#[derive(Debug, Clone, Copy, Default)]
pub struct AuthorizationCodeFlow;

/// Client credentials grant; implements [`TokenFlow`].
#[derive(Debug, Clone, Copy, Default)]
pub struct ClientCredentialsFlow;

196#[async_trait(?Send)]
197impl AuthorizationFlow for AuthorizationCodeFlow {
198    type Response = AuthorizationResponse;
199
200    async fn verify_authorization(
201        &self,
202        client: &Client,
203        request: AuthorizationRequest,
204    ) -> Result<VerifiedAuthorizationRequest, OAuthError> {
205        let state = request.state.clone();
206        if request.response_type != ResponseType::Code {
207            return Err(OAuthError::new(ErrorCode::InvalidRequest, state));
208        }
209
210        let code_challenge_method = if let Some(method) = request.code_challenge_method {
211            method
212        } else {
213            CodeChallengeMethod::Plain
214        };
215
216        let redirect_uri = match request.redirect_uri {
217            Some(ref redirect_uri) if !client.redirect_uris.contains(redirect_uri) => {
218                return Err(OAuthError::new(ErrorCode::InvalidRequest, state))
219            }
220            None if client.redirect_uris.len() > 1 => {
221                return Err(OAuthError::new(ErrorCode::InvalidRequest, state))
222            }
223            Some(redirect_uri) => redirect_uri,
224            None => client.redirect_uris[0].to_string(),
225        };
226
227        let redirect_uri_parsed = Url::parse(&redirect_uri)
228            .map_err(|_| OAuthError::new(ErrorCode::InvalidRequest, state.clone()))?;
229
230        if redirect_uri_parsed.scheme() != "https"
231            || redirect_uri_parsed.scheme() == "http"
232                && redirect_uri_parsed.host() != Some(Domain("localhost"))
233        {
234            return Err(OAuthError::new(ErrorCode::InvalidRequest, state));
235        }
236
237        let scope = match request.scope {
238            Some(scope) => {
239                if scope.split_ascii_whitespace().any(|scope| {
240                    !client
241                        .scopes
242                        .iter()
243                        .find(|defined_scope| *defined_scope == scope)
244                        .is_some()
245                }) {
246                    return Err(OAuthError::new(ErrorCode::InvalidScope, state));
247                }
248
249                scope
250            }
251            None => "profile email".to_string(),
252        };
253
254        Ok(VerifiedAuthorizationRequest {
255            client_id: request.client_id.to_string(),
256            code_challenge: request.code_challenge.to_string(),
257            code_challenge_method,
258            redirect_uri,
259            scope,
260            state: request.state.clone(),
261        })
262    }
263
264    async fn perform_authorization(
265        &self,
266        provider: &impl Provider,
267        request: VerifiedAuthorizationRequest,
268    ) -> Result<Self::Response, OAuthError> {
269        let authorization_code = random_secure_string(32);
270        let information = AuthorizationInformation {
271            client_id: request.client_id,
272            redirect_uri: request.redirect_uri,
273            code_challenge: request.code_challenge,
274            code_challenge_method: request.code_challenge_method,
275            scope: request.scope,
276            state: request.state.clone(),
277            is_valid: true,
278        };
279
280        provider
281            .save_authorization_information(authorization_code.clone(), information)
282            .await?;
283        Ok(AuthorizationResponse {
284            code: authorization_code,
285            state: request.state.clone(),
286        })
287    }
288}
289
290#[async_trait(?Send)]
291impl TokenFlow for AuthorizationCodeFlow {
292    type Request = TokenRequest;
293    type VerifiedRequest = VerifiedTokenRequest;
294    type Response = TokenResponse;
295
296    async fn verify_token_request(
297        &self,
298        authenticated_client: Option<Client>,
299        request: Self::Request,
300        provider: &impl Provider,
301    ) -> Result<Self::VerifiedRequest, OAuthError> {
302        if authenticated_client.is_some() && request.client_id.is_some()
303            || authenticated_client.is_none() && request.client_id.is_none()
304        {
305            return Err(OAuthError::new(ErrorCode::InvalidRequest, None));
306        }
307
308        if request.grant_type != GrantType::AuthorizationCode {
309            return Err(OAuthError::new(ErrorCode::InvalidRequest, None));
310        }
311
312        let client = if let Some(client) = authenticated_client {
313            client
314        } else {
315            let client = provider
316                .get_client(&request.client_id.unwrap())
317                .await
318                .ok_or(OAuthError::new(ErrorCode::InvalidRequest, None))?;
319            if client.secret.is_some() {
320                return Err(OAuthError::new(ErrorCode::InvalidRequest, None));
321            }
322
323            client
324        };
325
326        if client.redirect_uris.len() > 1 && request.redirect_uri.is_none() {
327            return Err(OAuthError::new(ErrorCode::InvalidRequest, None));
328        }
329
330        let authentication_information = provider
331            .get_authorization_information(&request.code)
332            .await
333            .ok_or(OAuthError::new(ErrorCode::AccessDenied, None))?;
334
335        if !authentication_information.is_valid {
336            return Err(OAuthError::new(ErrorCode::AccessDenied, None));
337        }
338
339        if let Some(redirect_uri) = request.redirect_uri {
340            if authentication_information.redirect_uri != redirect_uri {
341                return Err(OAuthError::new(ErrorCode::AccessDenied, None));
342            }
343        } else if authentication_information.redirect_uri != client.redirect_uris[0].to_string() {
344            return Err(OAuthError::new(ErrorCode::AccessDenied, None));
345        }
346
347        if authentication_information.client_id != client.id {
348            return Err(OAuthError::new(ErrorCode::AccessDenied, None));
349        }
350
351        match authentication_information.code_challenge_method {
352            CodeChallengeMethod::Plain => {
353                if request.code_verifier != authentication_information.code_challenge {
354                    return Err(OAuthError::new(ErrorCode::AccessDenied, None));
355                }
356            }
357            CodeChallengeMethod::S256 => {
358                if &sha256(&request.code_verifier)
359                    != authentication_information.code_challenge.as_bytes()
360                {
361                    return Err(OAuthError::new(ErrorCode::AccessDenied, None));
362                }
363            }
364        }
365
366        provider
367            .remove_authorization_information(&request.code)
368            .await?;
369        Ok(VerifiedTokenRequest {
370            authentication_information,
371        })
372    }
373
374    async fn perform_token_exchange(
375        &self,
376        request: Self::VerifiedRequest,
377    ) -> Result<Self::Response, OAuthError> {
378        Ok(TokenResponse {
379            access_token: random_secure_string(24),
380            token_type: "Bearer".to_string(),
381            expires_in: 3600,
382            scope: request.authentication_information.scope,
383        })
384    }
385}
386
387#[async_trait(?Send)]
388impl TokenFlow for ClientCredentialsFlow {
389    type Request = ClientCredentialsTokenRequest;
390    type VerifiedRequest = VerifiedClientCredentialsTokenRequest;
391    type Response = TokenResponse;
392
393    async fn verify_token_request(
394        &self,
395        authenticated_client: Option<Client>,
396        request: Self::Request,
397        _: &impl Provider,
398    ) -> Result<Self::VerifiedRequest, OAuthError> {
399        if authenticated_client.is_none() {
400            return Err(OAuthError::new(ErrorCode::UnauthorizedClient, None));
401        }
402
403        if request.grant_type != GrantType::ClientCredentials {
404            return Err(OAuthError::new(ErrorCode::InvalidRequest, None));
405        }
406
407        let scope = if let Some(scope) = request.scope {
408            if scope.split_ascii_whitespace().any(|scope| {
409                !authenticated_client
410                    .as_ref()
411                    .unwrap()
412                    .scopes
413                    .iter()
414                    .find(|defined_scope| *defined_scope == scope)
415                    .is_some()
416            }) {
417                return Err(OAuthError::new(ErrorCode::InvalidScope, None));
418            }
419
420            scope
421        } else {
422            "profile email".to_string()
423        };
424
425        Ok(VerifiedClientCredentialsTokenRequest { scope })
426    }
427
428    async fn perform_token_exchange(
429        &self,
430        request: Self::VerifiedRequest,
431    ) -> Result<Self::Response, OAuthError> {
432        Ok(TokenResponse {
433            access_token: random_secure_string(24),
434            token_type: "Bearer".to_string(),
435            expires_in: 3600,
436            scope: request.scope,
437        })
438    }
439}
440
/// Read access to the parts of an HTTP request that client authenticators inspect.
pub trait HttpRequestDetails {
    /// Request headers, keyed by header name.
    fn get_headers(&self) -> HashMap<String, String>;

    /// Form fields of the request body, keyed by field name.
    fn get_form_values(&self) -> HashMap<String, String>;
}

446#[async_trait(?Send)]
447pub trait ClientAuthenticator {
448    async fn authenticate_client(
449        &self,
450        provider: &impl Provider,
451        details: &impl HttpRequestDetails,
452    ) -> Result<Client, OAuthError>;
453}
454
455pub struct ClientSecretBasic;
456
457#[async_trait(?Send)]
458impl ClientAuthenticator for ClientSecretBasic {
459    async fn authenticate_client(
460        &self,
461        provider: &impl Provider,
462        details: &impl HttpRequestDetails,
463    ) -> Result<Client, OAuthError> {
464        let auth_headers = details.get_headers();
465        if let Some(value) = auth_headers.get("Authorization") {
466            let value = decode_base64(value)
467                .map_err(|_| OAuthError::new(ErrorCode::InvalidRequest, None))?;
468            let mut iter = value.split(":").take(2);
469            let client_id = iter
470                .next()
471                .ok_or(OAuthError::new(ErrorCode::InvalidRequest, None))?;
472            let client_secret = iter
473                .next()
474                .ok_or(OAuthError::new(ErrorCode::InvalidRequest, None))?;
475            match provider.get_client(client_id).await {
476                Some(client) if client.secret.as_deref() == Some(client_secret) => {
477                    if client.token_endpoint_auth_method
478                        == TokenEndpointAuthMethod::ClientSecretBasic
479                    {
480                        Ok(client)
481                    } else {
482                        Err(OAuthError::new(ErrorCode::InvalidRequest, None))
483                    }
484                }
485                Some(_) => return Err(OAuthError::new(ErrorCode::InvalidRequest, None)),
486                None => return Err(OAuthError::new(ErrorCode::InvalidRequest, None)),
487            }
488        } else {
489            Err(OAuthError::new(ErrorCode::InvalidRequest, None))
490        }
491    }
492}
493
494pub struct ClientSecretPost;
495
496#[async_trait(?Send)]
497impl ClientAuthenticator for ClientSecretPost {
498    async fn authenticate_client(
499        &self,
500        provider: &impl Provider,
501        details: &impl HttpRequestDetails,
502    ) -> Result<Client, OAuthError> {
503        let auth_form_values = details.get_form_values();
504        let client_id = auth_form_values
505            .get("client_id")
506            .ok_or(OAuthError::new(ErrorCode::InvalidRequest, None))?;
507        let client_secret = auth_form_values
508            .get("client_secret")
509            .ok_or(OAuthError::new(ErrorCode::InvalidRequest, None))?;
510        match provider.get_client(client_id).await {
511            Some(client) if client.secret.as_deref() == Some(client_secret) => {
512                if client.token_endpoint_auth_method == TokenEndpointAuthMethod::ClientSecretPost {
513                    Ok(client)
514                } else {
515                    Err(OAuthError::new(ErrorCode::InvalidRequest, None))
516                }
517            }
518            Some(_) => Err(OAuthError::new(ErrorCode::InvalidRequest, None)),
519            None => Err(OAuthError::new(ErrorCode::InvalidRequest, None)),
520        }
521    }
522}