1use std::collections::HashMap;
2
3use async_trait::async_trait;
4use url::{Host::Domain, Url};
5
6use crate::{
7 client::{Client, GrantType, ResponseType, TokenEndpointAuthMethod},
8 crypto::{decode_base64, random_secure_string, sha256},
9 error::{ErrorCode, OAuthError},
10};
11
/// PKCE code challenge method selected by the client for an authorization request.
#[derive(Debug, Clone, Copy)]
pub enum CodeChallengeMethod {
    /// The code verifier is compared to the stored challenge unchanged.
    Plain,
    /// The SHA-256 digest of the code verifier is compared to the stored challenge.
    S256,
}
17
/// Unverified parameters of an incoming authorization request.
pub struct AuthorizationRequest {
    /// Requested response type; the authorization code flow only accepts `ResponseType::Code`.
    pub response_type: ResponseType,
    /// Identifier used to look up the requesting client.
    pub client_id: String,
    /// Challenge stored with the issued code and checked against the token request's verifier.
    pub code_challenge: String,
    /// How the challenge was derived; the authorization code flow falls back to `Plain`.
    pub code_challenge_method: Option<CodeChallengeMethod>,
    /// Redirect target; may be omitted when the client has exactly one registered URI.
    pub redirect_uri: Option<String>,
    /// Whitespace-separated scope names; "profile email" is substituted when absent.
    pub scope: Option<String>,
    /// Opaque client value carried into the response and into errors.
    pub state: Option<String>,
}
27
/// Unverified parameters of a token request for the authorization code flow.
pub struct TokenRequest {
    /// Requested grant; must be `GrantType::AuthorizationCode` for this flow.
    pub grant_type: GrantType,
    /// Authorization code previously issued by the authorization step.
    pub code: String,
    /// Redirect URI; required when the client has more than one registered URI.
    pub redirect_uri: Option<String>,
    /// Client identifier; must be present exactly when no authenticated client is supplied.
    pub client_id: Option<String>,
    /// Verifier checked against the code challenge stored with the authorization code.
    pub code_verifier: String,
}
35
/// Unverified parameters of a token request for the client credentials flow.
pub struct ClientCredentialsTokenRequest {
    /// Requested grant; must be `GrantType::ClientCredentials` for this flow.
    pub grant_type: GrantType,
    /// Whitespace-separated scope names; "profile email" is substituted when absent.
    pub scope: Option<String>,
}
40
/// Result of a successful authorization in the authorization code flow.
#[derive(Debug)]
pub struct AuthorizationResponse {
    /// Newly generated authorization code.
    pub code: String,
    /// The `state` value of the originating request, if any.
    pub state: Option<String>,
}
46
/// Result of a successful token exchange.
#[derive(Debug)]
pub struct TokenResponse {
    /// Randomly generated access token.
    access_token: String,
    /// Token type; the flows in this file always set "Bearer".
    token_type: String,
    /// Token lifetime; the flows in this file set 3600 (presumably seconds — confirm).
    expires_in: u64,
    /// Whitespace-separated scope names granted to the token.
    scope: String,
}
54
/// Authorization request that passed `AuthorizationFlow::verify_authorization`,
/// with all optional inputs resolved to concrete values.
pub struct VerifiedAuthorizationRequest {
    /// Identifier of the requesting client.
    client_id: String,
    /// Code challenge supplied by the client.
    code_challenge: String,
    /// Resolved code challenge method.
    code_challenge_method: CodeChallengeMethod,
    /// Resolved redirect URI.
    redirect_uri: String,
    /// Resolved, whitespace-separated scope names.
    scope: String,
    /// The `state` value of the originating request, if any.
    state: Option<String>,
}
63
/// Token request of the authorization code flow after successful verification.
pub struct VerifiedTokenRequest {
    /// Information stored when the authorization code was issued.
    authentication_information: AuthorizationInformation,
}
67
/// Token request of the client credentials flow after successful verification.
pub struct VerifiedClientCredentialsTokenRequest {
    /// Resolved, whitespace-separated scope names to grant.
    pub scope: String,
}
71
/// State persisted under an authorization code between the authorization step
/// and the token exchange.
#[derive(Debug, Clone)]
pub struct AuthorizationInformation {
    /// Client the code was issued to; must match the client redeeming it.
    client_id: String,
    /// Redirect URI used for the authorization; must match the one in the token request.
    redirect_uri: String,
    /// Code challenge the token request's verifier is checked against.
    code_challenge: String,
    /// Method used to compare verifier and challenge.
    code_challenge_method: CodeChallengeMethod,
    /// Whitespace-separated scope names granted by the authorization.
    scope: String,
    /// The `state` value of the originating request, if any.
    state: Option<String>,
    /// Set to `true` on creation; a token request is refused when this is `false`.
    is_valid: bool,
}
82
/// Client metadata returned next to a verified authorization request
/// (presumably for rendering a sign-in / consent page — confirm with callers).
/// Every field is copied from the registered `Client`.
#[derive(Debug)]
pub struct SigninInformation {
    /// The client's name.
    client_name: String,
    /// The client's URI.
    client_uri: String,
    /// The client's logo URI.
    logo_uri: String,
    /// Scopes registered for the client.
    scopes: Vec<String>,
    /// The client's contacts.
    contacts: Vec<String>,
    /// The client's terms-of-service URI.
    tos_uri: String,
    /// The client's policy URI.
    policy_uri: String,
}
93
94#[async_trait(?Send)]
95pub trait Provider
96where
97 Self: Sized,
98{
99 async fn store_client(&self, client: Client) -> Result<(), OAuthError>;
100 async fn get_client(&self, client_id: &str) -> Option<Client>;
101 async fn save_authorization_information(
102 &self,
103 id: String,
104 information: AuthorizationInformation,
105 ) -> Result<(), OAuthError>;
106 async fn get_authorization_information(&self, id: &str) -> Option<AuthorizationInformation>;
107 async fn remove_authorization_information(&self, id: &str) -> Result<(), OAuthError>;
108
109 async fn verify_authorization_request<T: AuthorizationFlow>(
110 &self,
111 request: AuthorizationRequest,
112 flow: T,
113 ) -> Result<(VerifiedAuthorizationRequest, SigninInformation), OAuthError> {
114 let client = self
115 .get_client(&request.client_id)
116 .await
117 .ok_or(OAuthError::new(
118 ErrorCode::InvalidRequest,
119 request.state.clone(),
120 ))?;
121 let request = flow.verify_authorization(&client, request).await?;
122 Ok((
123 request,
124 SigninInformation {
125 client_name: client.name,
126 client_uri: client.uri,
127 logo_uri: client.logo_uri,
128 scopes: client.scopes,
129 contacts: client.contacts,
130 tos_uri: client.tos_uri,
131 policy_uri: client.policy_uri,
132 },
133 ))
134 }
135
136 async fn authorize<T: AuthorizationFlow>(
137 &self,
138 request: VerifiedAuthorizationRequest,
139 flow: T,
140 ) -> Result<T::Response, OAuthError> {
141 flow.perform_authorization(self, request).await
142 }
143
144 async fn get_token<T: TokenFlow>(
145 &self,
146 authenticated_client: Option<Client>,
147 request: T::Request,
148 flow: T,
149 ) -> Result<T::Response, OAuthError> {
150 let request = flow
151 .verify_token_request(authenticated_client, request, self)
152 .await?;
153 flow.perform_token_exchange(request).await
154 }
155}
156
/// A flow that can verify and then perform an authorization request.
#[async_trait(?Send)]
pub trait AuthorizationFlow {
    /// Value produced by a successful authorization.
    type Response;

    /// Validates `request` against the registered `client` and resolves its optional fields.
    async fn verify_authorization(
        &self,
        client: &Client,
        request: AuthorizationRequest,
    ) -> Result<VerifiedAuthorizationRequest, OAuthError>;

    /// Carries out the authorization for a verified request, using `provider` for storage.
    async fn perform_authorization(
        &self,
        provider: &impl Provider,
        request: VerifiedAuthorizationRequest,
    ) -> Result<Self::Response, OAuthError>;
}
173
/// A flow that can verify a token request and exchange it for a token.
#[async_trait(?Send)]
pub trait TokenFlow {
    /// Unverified token request accepted by the flow.
    type Request;
    /// Token request after verification.
    type VerifiedRequest;
    /// Value produced by a successful token exchange.
    type Response;

    /// Validates `request`; `authenticated_client` is the client established by a
    /// client authenticator, if any.
    async fn verify_token_request(
        &self,
        authenticated_client: Option<Client>,
        request: Self::Request,
        provider: &impl Provider,
    ) -> Result<Self::VerifiedRequest, OAuthError>;

    /// Produces the response for a verified request.
    async fn perform_token_exchange(
        &self,
        request: Self::VerifiedRequest,
    ) -> Result<Self::Response, OAuthError>;
}
192
/// Marker type implementing the authorization code flow (authorization and token steps).
pub struct AuthorizationCodeFlow;
/// Marker type implementing the client credentials flow (token step only).
pub struct ClientCredentialsFlow;
195
196#[async_trait(?Send)]
197impl AuthorizationFlow for AuthorizationCodeFlow {
198 type Response = AuthorizationResponse;
199
200 async fn verify_authorization(
201 &self,
202 client: &Client,
203 request: AuthorizationRequest,
204 ) -> Result<VerifiedAuthorizationRequest, OAuthError> {
205 let state = request.state.clone();
206 if request.response_type != ResponseType::Code {
207 return Err(OAuthError::new(ErrorCode::InvalidRequest, state));
208 }
209
210 let code_challenge_method = if let Some(method) = request.code_challenge_method {
211 method
212 } else {
213 CodeChallengeMethod::Plain
214 };
215
216 let redirect_uri = match request.redirect_uri {
217 Some(ref redirect_uri) if !client.redirect_uris.contains(redirect_uri) => {
218 return Err(OAuthError::new(ErrorCode::InvalidRequest, state))
219 }
220 None if client.redirect_uris.len() > 1 => {
221 return Err(OAuthError::new(ErrorCode::InvalidRequest, state))
222 }
223 Some(redirect_uri) => redirect_uri,
224 None => client.redirect_uris[0].to_string(),
225 };
226
227 let redirect_uri_parsed = Url::parse(&redirect_uri)
228 .map_err(|_| OAuthError::new(ErrorCode::InvalidRequest, state.clone()))?;
229
230 if redirect_uri_parsed.scheme() != "https"
231 || redirect_uri_parsed.scheme() == "http"
232 && redirect_uri_parsed.host() != Some(Domain("localhost"))
233 {
234 return Err(OAuthError::new(ErrorCode::InvalidRequest, state));
235 }
236
237 let scope = match request.scope {
238 Some(scope) => {
239 if scope.split_ascii_whitespace().any(|scope| {
240 !client
241 .scopes
242 .iter()
243 .find(|defined_scope| *defined_scope == scope)
244 .is_some()
245 }) {
246 return Err(OAuthError::new(ErrorCode::InvalidScope, state));
247 }
248
249 scope
250 }
251 None => "profile email".to_string(),
252 };
253
254 Ok(VerifiedAuthorizationRequest {
255 client_id: request.client_id.to_string(),
256 code_challenge: request.code_challenge.to_string(),
257 code_challenge_method,
258 redirect_uri,
259 scope,
260 state: request.state.clone(),
261 })
262 }
263
264 async fn perform_authorization(
265 &self,
266 provider: &impl Provider,
267 request: VerifiedAuthorizationRequest,
268 ) -> Result<Self::Response, OAuthError> {
269 let authorization_code = random_secure_string(32);
270 let information = AuthorizationInformation {
271 client_id: request.client_id,
272 redirect_uri: request.redirect_uri,
273 code_challenge: request.code_challenge,
274 code_challenge_method: request.code_challenge_method,
275 scope: request.scope,
276 state: request.state.clone(),
277 is_valid: true,
278 };
279
280 provider
281 .save_authorization_information(authorization_code.clone(), information)
282 .await?;
283 Ok(AuthorizationResponse {
284 code: authorization_code,
285 state: request.state.clone(),
286 })
287 }
288}
289
290#[async_trait(?Send)]
291impl TokenFlow for AuthorizationCodeFlow {
292 type Request = TokenRequest;
293 type VerifiedRequest = VerifiedTokenRequest;
294 type Response = TokenResponse;
295
296 async fn verify_token_request(
297 &self,
298 authenticated_client: Option<Client>,
299 request: Self::Request,
300 provider: &impl Provider,
301 ) -> Result<Self::VerifiedRequest, OAuthError> {
302 if authenticated_client.is_some() && request.client_id.is_some()
303 || authenticated_client.is_none() && request.client_id.is_none()
304 {
305 return Err(OAuthError::new(ErrorCode::InvalidRequest, None));
306 }
307
308 if request.grant_type != GrantType::AuthorizationCode {
309 return Err(OAuthError::new(ErrorCode::InvalidRequest, None));
310 }
311
312 let client = if let Some(client) = authenticated_client {
313 client
314 } else {
315 let client = provider
316 .get_client(&request.client_id.unwrap())
317 .await
318 .ok_or(OAuthError::new(ErrorCode::InvalidRequest, None))?;
319 if client.secret.is_some() {
320 return Err(OAuthError::new(ErrorCode::InvalidRequest, None));
321 }
322
323 client
324 };
325
326 if client.redirect_uris.len() > 1 && request.redirect_uri.is_none() {
327 return Err(OAuthError::new(ErrorCode::InvalidRequest, None));
328 }
329
330 let authentication_information = provider
331 .get_authorization_information(&request.code)
332 .await
333 .ok_or(OAuthError::new(ErrorCode::AccessDenied, None))?;
334
335 if !authentication_information.is_valid {
336 return Err(OAuthError::new(ErrorCode::AccessDenied, None));
337 }
338
339 if let Some(redirect_uri) = request.redirect_uri {
340 if authentication_information.redirect_uri != redirect_uri {
341 return Err(OAuthError::new(ErrorCode::AccessDenied, None));
342 }
343 } else if authentication_information.redirect_uri != client.redirect_uris[0].to_string() {
344 return Err(OAuthError::new(ErrorCode::AccessDenied, None));
345 }
346
347 if authentication_information.client_id != client.id {
348 return Err(OAuthError::new(ErrorCode::AccessDenied, None));
349 }
350
351 match authentication_information.code_challenge_method {
352 CodeChallengeMethod::Plain => {
353 if request.code_verifier != authentication_information.code_challenge {
354 return Err(OAuthError::new(ErrorCode::AccessDenied, None));
355 }
356 }
357 CodeChallengeMethod::S256 => {
358 if &sha256(&request.code_verifier)
359 != authentication_information.code_challenge.as_bytes()
360 {
361 return Err(OAuthError::new(ErrorCode::AccessDenied, None));
362 }
363 }
364 }
365
366 provider
367 .remove_authorization_information(&request.code)
368 .await?;
369 Ok(VerifiedTokenRequest {
370 authentication_information,
371 })
372 }
373
374 async fn perform_token_exchange(
375 &self,
376 request: Self::VerifiedRequest,
377 ) -> Result<Self::Response, OAuthError> {
378 Ok(TokenResponse {
379 access_token: random_secure_string(24),
380 token_type: "Bearer".to_string(),
381 expires_in: 3600,
382 scope: request.authentication_information.scope,
383 })
384 }
385}
386
387#[async_trait(?Send)]
388impl TokenFlow for ClientCredentialsFlow {
389 type Request = ClientCredentialsTokenRequest;
390 type VerifiedRequest = VerifiedClientCredentialsTokenRequest;
391 type Response = TokenResponse;
392
393 async fn verify_token_request(
394 &self,
395 authenticated_client: Option<Client>,
396 request: Self::Request,
397 _: &impl Provider,
398 ) -> Result<Self::VerifiedRequest, OAuthError> {
399 if authenticated_client.is_none() {
400 return Err(OAuthError::new(ErrorCode::UnauthorizedClient, None));
401 }
402
403 if request.grant_type != GrantType::ClientCredentials {
404 return Err(OAuthError::new(ErrorCode::InvalidRequest, None));
405 }
406
407 let scope = if let Some(scope) = request.scope {
408 if scope.split_ascii_whitespace().any(|scope| {
409 !authenticated_client
410 .as_ref()
411 .unwrap()
412 .scopes
413 .iter()
414 .find(|defined_scope| *defined_scope == scope)
415 .is_some()
416 }) {
417 return Err(OAuthError::new(ErrorCode::InvalidScope, None));
418 }
419
420 scope
421 } else {
422 "profile email".to_string()
423 };
424
425 Ok(VerifiedClientCredentialsTokenRequest { scope })
426 }
427
428 async fn perform_token_exchange(
429 &self,
430 request: Self::VerifiedRequest,
431 ) -> Result<Self::Response, OAuthError> {
432 Ok(TokenResponse {
433 access_token: random_secure_string(24),
434 token_type: "Bearer".to_string(),
435 expires_in: 3600,
436 scope: request.scope,
437 })
438 }
439}
440
/// Access to the parts of an HTTP request that the client authenticators read.
pub trait HttpRequestDetails {
    /// Returns the request headers as a name-to-value map.
    fn get_headers(&self) -> HashMap<String, String>;
    /// Returns the submitted form fields as a name-to-value map.
    fn get_form_values(&self) -> HashMap<String, String>;
}
445
/// Strategy for establishing which client sent a request.
#[async_trait(?Send)]
pub trait ClientAuthenticator {
    /// Authenticates the client described by `details`, looking it up through `provider`.
    async fn authenticate_client(
        &self,
        provider: &impl Provider,
        details: &impl HttpRequestDetails,
    ) -> Result<Client, OAuthError>;
}
454
455pub struct ClientSecretBasic;
456
457#[async_trait(?Send)]
458impl ClientAuthenticator for ClientSecretBasic {
459 async fn authenticate_client(
460 &self,
461 provider: &impl Provider,
462 details: &impl HttpRequestDetails,
463 ) -> Result<Client, OAuthError> {
464 let auth_headers = details.get_headers();
465 if let Some(value) = auth_headers.get("Authorization") {
466 let value = decode_base64(value)
467 .map_err(|_| OAuthError::new(ErrorCode::InvalidRequest, None))?;
468 let mut iter = value.split(":").take(2);
469 let client_id = iter
470 .next()
471 .ok_or(OAuthError::new(ErrorCode::InvalidRequest, None))?;
472 let client_secret = iter
473 .next()
474 .ok_or(OAuthError::new(ErrorCode::InvalidRequest, None))?;
475 match provider.get_client(client_id).await {
476 Some(client) if client.secret.as_deref() == Some(client_secret) => {
477 if client.token_endpoint_auth_method
478 == TokenEndpointAuthMethod::ClientSecretBasic
479 {
480 Ok(client)
481 } else {
482 Err(OAuthError::new(ErrorCode::InvalidRequest, None))
483 }
484 }
485 Some(_) => return Err(OAuthError::new(ErrorCode::InvalidRequest, None)),
486 None => return Err(OAuthError::new(ErrorCode::InvalidRequest, None)),
487 }
488 } else {
489 Err(OAuthError::new(ErrorCode::InvalidRequest, None))
490 }
491 }
492}
493
494pub struct ClientSecretPost;
495
496#[async_trait(?Send)]
497impl ClientAuthenticator for ClientSecretPost {
498 async fn authenticate_client(
499 &self,
500 provider: &impl Provider,
501 details: &impl HttpRequestDetails,
502 ) -> Result<Client, OAuthError> {
503 let auth_form_values = details.get_form_values();
504 let client_id = auth_form_values
505 .get("client_id")
506 .ok_or(OAuthError::new(ErrorCode::InvalidRequest, None))?;
507 let client_secret = auth_form_values
508 .get("client_secret")
509 .ok_or(OAuthError::new(ErrorCode::InvalidRequest, None))?;
510 match provider.get_client(client_id).await {
511 Some(client) if client.secret.as_deref() == Some(client_secret) => {
512 if client.token_endpoint_auth_method == TokenEndpointAuthMethod::ClientSecretPost {
513 Ok(client)
514 } else {
515 Err(OAuthError::new(ErrorCode::InvalidRequest, None))
516 }
517 }
518 Some(_) => Err(OAuthError::new(ErrorCode::InvalidRequest, None)),
519 None => Err(OAuthError::new(ErrorCode::InvalidRequest, None)),
520 }
521 }
522}