1use axum::extract::Extension;
2use axum::http::header::COOKIE;
3use axum::http::{HeaderMap, StatusCode};
4use axum::response::{IntoResponse, Response};
5use axum::{Form, Json};
6use serde::Deserialize;
7use serde_json::json;
8use url::Url;
9
10#[cfg(test)]
11use allowthem_core::applications::CreateApplicationParams;
12use allowthem_core::applications::{Application, BrandingConfig, validate_redirect_uri};
13use allowthem_core::authorization::{
14 generate_authorization_code, hash_authorization_code, validate_scopes,
15};
16use allowthem_core::types::{ClientId, UserId};
17use allowthem_core::{AllowThem, AuthError};
18
/// Error codes for the authorization endpoint (RFC 6749 §4.1.2.1 / §4.1.2.1
/// `error` values), restricted to the ones this module emits.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum OAuthErrorCode {
    /// A required parameter (`state`, PKCE) is missing or malformed.
    InvalidRequest,
    /// The user declined the consent prompt.
    AccessDenied,
    /// `response_type` was something other than `code`.
    UnsupportedResponseType,
    /// The requested scope string failed validation.
    InvalidScope,
    /// An unexpected storage or internal failure occurred.
    ServerError,
}
30
31impl OAuthErrorCode {
32 fn as_str(&self) -> &'static str {
33 match self {
34 Self::InvalidRequest => "invalid_request",
35 Self::AccessDenied => "access_denied",
36 Self::UnsupportedResponseType => "unsupported_response_type",
37 Self::InvalidScope => "invalid_scope",
38 Self::ServerError => "server_error",
39 }
40 }
41}
42
43pub enum AuthorizeOutcome {
50 Redirect(Response),
52 ConsentNeeded(Box<ConsentNeededData>),
54}
55
56pub struct ConsentNeededData {
57 pub context: ConsentContext,
58 pub params: ValidatedAuthorize,
59}
60
61#[derive(Deserialize)]
66pub struct AuthorizeParams {
67 pub client_id: Option<ClientId>,
68 pub redirect_uri: Option<String>,
69 pub response_type: Option<String>,
70 pub scope: Option<String>,
71 pub state: Option<String>,
72 pub code_challenge: Option<String>,
73 pub code_challenge_method: Option<String>,
74 pub nonce: Option<String>,
75}
76
77#[derive(Deserialize)]
79pub struct ConsentSubmission {
80 client_id: Option<ClientId>,
81 redirect_uri: Option<String>,
82 response_type: Option<String>,
83 scope: Option<String>,
84 state: Option<String>,
85 code_challenge: Option<String>,
86 code_challenge_method: Option<String>,
87 nonce: Option<String>,
88 consent: String,
89 #[allow(dead_code)]
91 csrf_token: Option<String>,
92}
93
94pub struct ConsentContext {
96 pub branding: BrandingConfig,
97 pub scopes: Vec<String>,
98}
99
100pub struct ValidatedAuthorize {
102 pub application: Application,
103 pub redirect_uri: String,
104 pub scopes: Vec<String>,
105 pub state: String,
106 pub code_challenge: String,
107 pub code_challenge_method: String,
108 pub nonce: Option<String>,
109}
110
111fn success_redirect(redirect_uri: &str, code: &str, state: &str, status: StatusCode) -> Response {
117 let mut url = Url::parse(redirect_uri).expect("redirect_uri was pre-validated");
118 url.query_pairs_mut()
119 .append_pair("code", code)
120 .append_pair("state", state);
121 (status, [("location", url.as_str().to_string())]).into_response()
122}
123
124fn error_redirect(
126 redirect_uri: &str,
127 error: OAuthErrorCode,
128 description: &str,
129 state: &str,
130 status: StatusCode,
131) -> Response {
132 let mut url = Url::parse(redirect_uri).expect("redirect_uri was pre-validated");
133 url.query_pairs_mut()
134 .append_pair("error", error.as_str())
135 .append_pair("error_description", description)
136 .append_pair("state", state);
137 (status, [("location", url.as_str().to_string())]).into_response()
138}
139
140fn display_error(status: StatusCode, message: &str) -> Response {
142 (status, Json(json!({"error": message}))).into_response()
143}
144
145pub async fn resolve_user(
153 ath: &AllowThem,
154 headers: &HeaderMap,
155) -> Result<Option<allowthem_core::User>, AuthError> {
156 let cookie_str = match headers.get(COOKIE).and_then(|v| v.to_str().ok()) {
157 Some(c) => c,
158 None => return Ok(None),
159 };
160
161 let token =
162 match allowthem_core::parse_session_cookie(cookie_str, ath.session_config().cookie_name) {
163 Some(t) => t,
164 None => return Ok(None),
165 };
166
167 let session = match ath
168 .db()
169 .validate_session(&token, ath.session_config().ttl)
170 .await?
171 {
172 Some(s) => s,
173 None => return Ok(None),
174 };
175
176 match ath.db().get_user(session.user_id).await {
177 Ok(user) if user.is_active => Ok(Some(user)),
178 Ok(_) => Ok(None),
179 Err(AuthError::NotFound) => Ok(None),
180 Err(e) => Err(e),
181 }
182}
183
184pub async fn validate_authorize_params(
191 ath: &AllowThem,
192 params: &AuthorizeParams,
193) -> Result<ValidatedAuthorize, Response> {
194 let client_id = params
196 .client_id
197 .as_ref()
198 .ok_or_else(|| display_error(StatusCode::BAD_REQUEST, "missing client_id"))?;
199
200 let application = ath
201 .db()
202 .get_application_by_client_id(client_id)
203 .await
204 .map_err(|e| match e {
205 AuthError::NotFound => display_error(StatusCode::BAD_REQUEST, "unknown client_id"),
206 _ => display_error(StatusCode::INTERNAL_SERVER_ERROR, "internal error"),
207 })?;
208
209 if !application.is_active {
211 return Err(display_error(
212 StatusCode::BAD_REQUEST,
213 "application is inactive",
214 ));
215 }
216
217 let redirect_uri = params.redirect_uri.as_deref().unwrap_or("");
219 if redirect_uri.is_empty() {
220 return Err(display_error(
221 StatusCode::BAD_REQUEST,
222 "missing redirect_uri",
223 ));
224 }
225 let registered = application
226 .redirect_uri_list()
227 .map_err(|_| display_error(StatusCode::INTERNAL_SERVER_ERROR, "internal error"))?;
228 validate_redirect_uri(redirect_uri, ®istered)
229 .map_err(|_| display_error(StatusCode::BAD_REQUEST, "redirect_uri not registered"))?;
230
231 let redirect_uri = redirect_uri.to_string();
233
234 let state = match params.state.as_deref() {
236 Some(s) if !s.is_empty() => s.to_string(),
237 _ => {
238 return Err(error_redirect(
239 &redirect_uri,
240 OAuthErrorCode::InvalidRequest,
241 "missing state parameter",
242 "",
243 StatusCode::FOUND,
244 ));
245 }
246 };
247
248 if params.response_type.as_deref() != Some("code") {
250 return Err(error_redirect(
251 &redirect_uri,
252 OAuthErrorCode::UnsupportedResponseType,
253 "response_type must be code",
254 &state,
255 StatusCode::FOUND,
256 ));
257 }
258
259 let scope_str = params.scope.as_deref().unwrap_or("");
261 let scopes = validate_scopes(scope_str).map_err(|e| {
262 error_redirect(
263 &redirect_uri,
264 OAuthErrorCode::InvalidScope,
265 &e.to_string(),
266 &state,
267 StatusCode::FOUND,
268 )
269 })?;
270
271 let code_challenge = match params.code_challenge.as_deref() {
273 Some(c) if !c.is_empty() => c.to_string(),
274 _ => {
275 return Err(error_redirect(
276 &redirect_uri,
277 OAuthErrorCode::InvalidRequest,
278 "missing code_challenge (PKCE required)",
279 &state,
280 StatusCode::FOUND,
281 ));
282 }
283 };
284 let code_challenge_method = params.code_challenge_method.as_deref().unwrap_or("");
285 if code_challenge_method != "S256" {
286 return Err(error_redirect(
287 &redirect_uri,
288 OAuthErrorCode::InvalidRequest,
289 "code_challenge_method must be S256",
290 &state,
291 StatusCode::FOUND,
292 ));
293 }
294
295 Ok(ValidatedAuthorize {
296 application,
297 redirect_uri,
298 scopes,
299 state,
300 code_challenge,
301 code_challenge_method: "S256".to_string(),
302 nonce: params.nonce.clone(),
303 })
304}
305
306fn build_authorize_query_string(params: &AuthorizeParams) -> String {
312 let mut pairs = url::form_urlencoded::Serializer::new(String::new());
313 if let Some(ref v) = params.client_id {
314 pairs.append_pair("client_id", v.as_str());
315 }
316 if let Some(ref v) = params.redirect_uri {
317 pairs.append_pair("redirect_uri", v);
318 }
319 if let Some(ref v) = params.response_type {
320 pairs.append_pair("response_type", v);
321 }
322 if let Some(ref v) = params.scope {
323 pairs.append_pair("scope", v);
324 }
325 if let Some(ref v) = params.state {
326 pairs.append_pair("state", v);
327 }
328 if let Some(ref v) = params.code_challenge {
329 pairs.append_pair("code_challenge", v);
330 }
331 if let Some(ref v) = params.code_challenge_method {
332 pairs.append_pair("code_challenge_method", v);
333 }
334 if let Some(ref v) = params.nonce {
335 pairs.append_pair("nonce", v);
336 }
337 pairs.finish()
338}
339
340fn login_redirect(params: &AuthorizeParams) -> Response {
342 let full_uri = format!("/oauth/authorize?{}", build_authorize_query_string(params));
343 let encoded: String = url::form_urlencoded::byte_serialize(full_uri.as_bytes()).collect();
344 let mut redirect = format!("/login?next={encoded}");
345 if let Some(ref cid) = params.client_id {
346 redirect.push_str("&client_id=");
347 redirect.push_str(cid.as_str());
348 }
349 (StatusCode::SEE_OTHER, [("location", redirect)]).into_response()
350}
351
352pub async fn issue_code_and_redirect(
354 ath: &AllowThem,
355 validated: &ValidatedAuthorize,
356 user_id: UserId,
357 status: StatusCode,
358) -> Response {
359 let raw_code = generate_authorization_code();
360 let code_hash = hash_authorization_code(&raw_code);
361
362 match ath
363 .db()
364 .create_authorization_code(
365 validated.application.id,
366 user_id,
367 &code_hash,
368 &validated.redirect_uri,
369 &validated.scopes,
370 &validated.code_challenge,
371 &validated.code_challenge_method,
372 validated.nonce.as_deref(),
373 )
374 .await
375 {
376 Ok(_) => success_redirect(&validated.redirect_uri, &raw_code, &validated.state, status),
377 Err(_) => error_redirect(
378 &validated.redirect_uri,
379 OAuthErrorCode::ServerError,
380 "internal error",
381 &validated.state,
382 status,
383 ),
384 }
385}
386
387pub async fn check_authorization(
394 ath: &AllowThem,
395 headers: &HeaderMap,
396 params: &AuthorizeParams,
397) -> AuthorizeOutcome {
398 let validated = match validate_authorize_params(ath, params).await {
399 Ok(v) => v,
400 Err(resp) => return AuthorizeOutcome::Redirect(resp),
401 };
402
403 let user = match resolve_user(ath, headers).await {
405 Ok(Some(u)) => u,
406 Ok(None) => return AuthorizeOutcome::Redirect(login_redirect(params)),
407 Err(_) => {
408 return AuthorizeOutcome::Redirect(error_redirect(
409 &validated.redirect_uri,
410 OAuthErrorCode::ServerError,
411 "internal error",
412 &validated.state,
413 StatusCode::FOUND,
414 ));
415 }
416 };
417
418 let needs_consent = if validated.application.is_trusted {
420 false
421 } else {
422 match ath
423 .db()
424 .has_sufficient_consent(user.id, validated.application.id, &validated.scopes)
425 .await
426 {
427 Ok(has) => !has,
428 Err(_) => {
429 return AuthorizeOutcome::Redirect(error_redirect(
430 &validated.redirect_uri,
431 OAuthErrorCode::ServerError,
432 "internal error",
433 &validated.state,
434 StatusCode::FOUND,
435 ));
436 }
437 }
438 };
439
440 if needs_consent {
441 let context = ConsentContext {
442 branding: validated.application.branding(),
443 scopes: validated.scopes.clone(),
444 };
445 return AuthorizeOutcome::ConsentNeeded(Box::new(ConsentNeededData {
446 context,
447 params: validated,
448 }));
449 }
450
451 AuthorizeOutcome::Redirect(
453 issue_code_and_redirect(ath, &validated, user.id, StatusCode::FOUND).await,
454 )
455}
456
457pub async fn authorize_post(
458 Extension(ath): Extension<AllowThem>,
459 headers: HeaderMap,
460 Form(form): Form<ConsentSubmission>,
461) -> Response {
462 let params = AuthorizeParams {
464 client_id: form.client_id,
465 redirect_uri: form.redirect_uri,
466 response_type: form.response_type,
467 scope: form.scope,
468 state: form.state,
469 code_challenge: form.code_challenge,
470 code_challenge_method: form.code_challenge_method,
471 nonce: form.nonce,
472 };
473 let validated = match validate_authorize_params(&ath, ¶ms).await {
474 Ok(v) => v,
475 Err(resp) => return resp,
476 };
477
478 let user = match resolve_user(&ath, &headers).await {
480 Ok(Some(u)) => u,
481 Ok(None) => return login_redirect(¶ms),
482 Err(_) => {
483 return error_redirect(
484 &validated.redirect_uri,
485 OAuthErrorCode::ServerError,
486 "internal error",
487 &validated.state,
488 StatusCode::SEE_OTHER,
489 );
490 }
491 };
492
493 if form.consent != "approve" {
495 return error_redirect(
496 &validated.redirect_uri,
497 OAuthErrorCode::AccessDenied,
498 "user denied consent",
499 &validated.state,
500 StatusCode::SEE_OTHER,
501 );
502 }
503
504 if ath
506 .db()
507 .upsert_consent(user.id, validated.application.id, &validated.scopes)
508 .await
509 .is_err()
510 {
511 return error_redirect(
512 &validated.redirect_uri,
513 OAuthErrorCode::ServerError,
514 "internal error",
515 &validated.state,
516 StatusCode::SEE_OTHER,
517 );
518 }
519
520 issue_code_and_redirect(&ath, &validated, user.id, StatusCode::SEE_OTHER).await
522}
523
#[cfg(test)]
mod tests {
    use super::*;
    use allowthem_core::handle::AllowThemBuilder;
    use allowthem_core::types::{ClientType, Email};
    use axum::Router;
    use axum::body::Body;
    use axum::http::Request;
    use axum::routing::post;
    use tower::ServiceExt;

    /// Registered callback of the default test application.
    const CALLBACK: &str = "https://example.com/callback";

    /// Fresh in-memory instance; secure cookies are off because test requests
    /// do not use TLS.
    async fn test_ath() -> AllowThem {
        AllowThemBuilder::new("sqlite::memory:")
            .cookie_secure(false)
            .build()
            .await
            .unwrap()
    }

    /// Registration parameters for a confidential client with a single
    /// redirect URI and no branding overrides.
    fn new_app(
        name: &str,
        redirect_uri: &str,
        is_trusted: bool,
        created_by: Option<UserId>,
    ) -> CreateApplicationParams {
        CreateApplicationParams {
            name: name.to_string(),
            client_type: ClientType::Confidential,
            redirect_uris: vec![redirect_uri.to_string()],
            is_trusted,
            created_by,
            logo_url: None,
            primary_color: None,
            accent_hex: None,
            accent_ink: None,
            forced_mode: None,
            font_css_url: None,
            font_family: None,
            splash_text: None,
            splash_image_url: None,
            splash_primitive: None,
            splash_url: None,
            shader_cell_scale: None,
        }
    }

    /// Registers the untrusted "TestApp" (callback [`CALLBACK`]), owned by a
    /// freshly created admin user.
    async fn setup_application(ath: &AllowThem) -> Application {
        let email = Email::new("admin@example.com".into()).unwrap();
        let user = ath
            .db()
            .create_user(email, "password123", None, None)
            .await
            .unwrap();

        let (app, _) = ath
            .db()
            .create_application(new_app("TestApp", CALLBACK, false, Some(user.id)))
            .await
            .unwrap();
        app
    }

    /// Fully valid request for `app` (scope "openid profile", state "xyz").
    fn authorize_params(app: &Application) -> AuthorizeParams {
        AuthorizeParams {
            client_id: Some(app.client_id.clone()),
            redirect_uri: Some(CALLBACK.into()),
            response_type: Some("code".into()),
            scope: Some("openid profile".into()),
            state: Some("xyz".into()),
            code_challenge: Some("abc123".into()),
            code_challenge_method: Some("S256".into()),
            nonce: None,
        }
    }

    /// Minimal valid request (scope "openid", state "s", challenge "c").
    /// Error tests override a single field with struct-update syntax.
    fn minimal_params(client_id: Option<ClientId>) -> AuthorizeParams {
        AuthorizeParams {
            client_id,
            redirect_uri: Some(CALLBACK.into()),
            response_type: Some("code".into()),
            scope: Some("openid".into()),
            state: Some("s".into()),
            code_challenge: Some("c".into()),
            code_challenge_method: Some("S256".into()),
            nonce: None,
        }
    }

    fn expect_redirect(outcome: AuthorizeOutcome) -> Response {
        match outcome {
            AuthorizeOutcome::Redirect(resp) => resp,
            AuthorizeOutcome::ConsentNeeded(_) => {
                panic!("expected Redirect, got ConsentNeeded")
            }
        }
    }

    /// Runs the GET leg without a session cookie and unwraps the response.
    async fn authorize_anonymous(ath: &AllowThem, params: &AuthorizeParams) -> Response {
        expect_redirect(check_authorization(ath, &HeaderMap::new(), params).await)
    }

    /// The `Location` header of `resp`.
    fn location_of(resp: &Response) -> &str {
        resp.headers().get("location").unwrap().to_str().unwrap()
    }

    async fn read_body(resp: axum::http::Response<Body>) -> serde_json::Value {
        let bytes = axum::body::to_bytes(resp.into_body(), usize::MAX)
            .await
            .unwrap();
        serde_json::from_slice(&bytes).unwrap_or(serde_json::Value::Null)
    }

    /// Creates a user with a 24h session and returns its id plus a `Cookie`
    /// header value. The cookie name comes from the live session config so
    /// the helper keeps working if the default name changes.
    async fn create_session(ath: &AllowThem, email: &str) -> (UserId, String) {
        let email = Email::new(email.into()).unwrap();
        let user = ath
            .db()
            .create_user(email, "password123", None, None)
            .await
            .unwrap();
        let token = allowthem_core::generate_token();
        let hash = allowthem_core::hash_token(&token);
        let expires = chrono::Utc::now() + chrono::Duration::hours(24);
        ath.db()
            .create_session(user.id, hash, None, None, expires)
            .await
            .unwrap();
        let cookie = format!("{}={}", ath.session_config().cookie_name, token.as_str());
        (user.id, cookie)
    }

    fn headers_with_cookie(cookie: &str) -> HeaderMap {
        let mut headers = HeaderMap::new();
        headers.insert("cookie", cookie.parse().unwrap());
        headers
    }

    #[tokio::test]
    async fn missing_client_id_returns_400() {
        let ath = test_ath().await;
        let params = AuthorizeParams {
            redirect_uri: Some("x".into()),
            ..minimal_params(None)
        };
        let resp = authorize_anonymous(&ath, &params).await;
        assert_eq!(resp.status(), StatusCode::BAD_REQUEST);
        let body = read_body(resp).await;
        assert_eq!(body["error"], "missing client_id");
    }

    #[tokio::test]
    async fn unknown_client_id_returns_400() {
        let ath = test_ath().await;
        let params = AuthorizeParams {
            redirect_uri: Some("x".into()),
            ..minimal_params(serde_json::from_value(serde_json::json!("ath_nonexistent")).ok())
        };
        let resp = authorize_anonymous(&ath, &params).await;
        assert_eq!(resp.status(), StatusCode::BAD_REQUEST);
        let body = read_body(resp).await;
        assert_eq!(body["error"], "unknown client_id");
    }

    #[tokio::test]
    async fn unregistered_redirect_uri_returns_400() {
        let ath = test_ath().await;
        let application = setup_application(&ath).await;
        let params = AuthorizeParams {
            redirect_uri: Some("https://evil.example.com/callback".into()),
            ..minimal_params(Some(application.client_id.clone()))
        };
        let resp = authorize_anonymous(&ath, &params).await;
        assert_eq!(resp.status(), StatusCode::BAD_REQUEST);
        let body = read_body(resp).await;
        assert_eq!(body["error"], "redirect_uri not registered");
    }

    #[tokio::test]
    async fn missing_state_redirects_with_error() {
        let ath = test_ath().await;
        let application = setup_application(&ath).await;
        let params = AuthorizeParams {
            state: None,
            ..minimal_params(Some(application.client_id.clone()))
        };
        let resp = authorize_anonymous(&ath, &params).await;
        assert_eq!(resp.status(), StatusCode::FOUND);
        assert!(location_of(&resp).contains("error=invalid_request"));
    }

    #[tokio::test]
    async fn bad_response_type_redirects_with_error() {
        let ath = test_ath().await;
        let application = setup_application(&ath).await;
        let params = AuthorizeParams {
            response_type: Some("token".into()),
            ..minimal_params(Some(application.client_id.clone()))
        };
        let resp = authorize_anonymous(&ath, &params).await;
        assert_eq!(resp.status(), StatusCode::FOUND);
        let location = location_of(&resp);
        assert!(location.contains("error=unsupported_response_type"));
        assert!(location.contains("state=s"));
    }

    #[tokio::test]
    async fn invalid_scope_redirects_with_error() {
        let ath = test_ath().await;
        let application = setup_application(&ath).await;
        let params = AuthorizeParams {
            scope: Some("profile".into()),
            ..minimal_params(Some(application.client_id.clone()))
        };
        let resp = authorize_anonymous(&ath, &params).await;
        assert_eq!(resp.status(), StatusCode::FOUND);
        assert!(location_of(&resp).contains("error=invalid_scope"));
    }

    #[tokio::test]
    async fn missing_pkce_redirects_with_error() {
        let ath = test_ath().await;
        let application = setup_application(&ath).await;
        let params = AuthorizeParams {
            code_challenge: None,
            code_challenge_method: None,
            ..minimal_params(Some(application.client_id.clone()))
        };
        let resp = authorize_anonymous(&ath, &params).await;
        assert_eq!(resp.status(), StatusCode::FOUND);
        let location = location_of(&resp);
        assert!(location.contains("error=invalid_request"));
        assert!(location.contains("PKCE"));
    }

    #[tokio::test]
    async fn unauthenticated_redirects_to_login() {
        let ath = test_ath().await;
        let application = setup_application(&ath).await;
        let params = authorize_params(&application);
        let resp = authorize_anonymous(&ath, &params).await;
        assert_eq!(resp.status(), StatusCode::SEE_OTHER);
        let location = location_of(&resp);
        assert!(location.starts_with("/login?next="));
        assert!(location.contains("oauth%2Fauthorize"));
    }

    #[tokio::test]
    async fn trusted_app_skips_consent_and_redirects_with_code() {
        let ath = test_ath().await;
        let (_, cookie) = create_session(&ath, "trusted@example.com").await;
        let headers = headers_with_cookie(&cookie);

        let trusted_callback = "https://trusted.example.com/callback";
        let (trusted_app, _) = ath
            .db()
            .create_application(new_app("TrustedApp", trusted_callback, true, None))
            .await
            .unwrap();

        let params = AuthorizeParams {
            redirect_uri: Some(trusted_callback.into()),
            ..authorize_params(&trusted_app)
        };

        let resp = expect_redirect(check_authorization(&ath, &headers, &params).await);
        assert_eq!(resp.status(), StatusCode::FOUND);
        let location = location_of(&resp);
        assert!(location.contains("code="));
        assert!(location.contains("state=xyz"));
        assert!(location.starts_with(trusted_callback));
    }

    #[tokio::test]
    async fn untrusted_app_without_consent_returns_consent_needed() {
        let ath = test_ath().await;
        let (_, cookie) = create_session(&ath, "consent@example.com").await;
        let headers = headers_with_cookie(&cookie);
        let application = setup_application(&ath).await;
        let params = authorize_params(&application);

        match check_authorization(&ath, &headers, &params).await {
            AuthorizeOutcome::ConsentNeeded(data) => {
                assert_eq!(data.context.branding.application_name, "TestApp");
                assert_eq!(data.context.scopes, vec!["openid", "profile"]);
            }
            AuthorizeOutcome::Redirect(_) => panic!("expected ConsentNeeded, got Redirect"),
        }
    }

    #[tokio::test]
    async fn inactive_application_returns_400() {
        let ath = test_ath().await;
        let application = setup_application(&ath).await;

        sqlx::query("UPDATE allowthem_applications SET is_active = 0 WHERE id = ?")
            .bind(application.id)
            .execute(ath.db().pool())
            .await
            .unwrap();

        let params = minimal_params(Some(application.client_id.clone()));
        let resp = authorize_anonymous(&ath, &params).await;
        assert_eq!(resp.status(), StatusCode::BAD_REQUEST);
        let body = read_body(resp).await;
        assert_eq!(body["error"], "application is inactive");
    }

    #[tokio::test]
    async fn wrong_pkce_method_redirects_with_error() {
        let ath = test_ath().await;
        let application = setup_application(&ath).await;
        let params = AuthorizeParams {
            code_challenge_method: Some("plain".into()),
            ..minimal_params(Some(application.client_id.clone()))
        };
        let resp = authorize_anonymous(&ath, &params).await;
        assert_eq!(resp.status(), StatusCode::FOUND);
        let location = location_of(&resp);
        assert!(location.contains("error=invalid_request"));
        assert!(location.contains("state=s"));
    }

    #[tokio::test]
    async fn existing_consent_skips_consent_screen() {
        let ath = test_ath().await;
        let (user_id, cookie) = create_session(&ath, "existing_consent@example.com").await;
        let headers = headers_with_cookie(&cookie);
        let application = setup_application(&ath).await;

        ath.db()
            .upsert_consent(
                user_id,
                application.id,
                &["openid".to_string(), "profile".to_string()],
            )
            .await
            .unwrap();

        let params = authorize_params(&application);
        let resp = expect_redirect(check_authorization(&ath, &headers, &params).await);
        assert_eq!(resp.status(), StatusCode::FOUND);
        let location = location_of(&resp);
        assert!(location.contains("code="));
        assert!(location.contains("state=xyz"));
    }

    fn post_app(ath: AllowThem) -> Router {
        Router::new()
            .route("/oauth/authorize", post(authorize_post))
            .layer(axum::middleware::from_fn_with_state(
                ath,
                crate::cors::inject_ath_into_extensions,
            ))
    }

    /// Urlencoded consent form targeting [`CALLBACK`] with S256 PKCE.
    fn consent_form(
        client_id: &str,
        scope: &str,
        state: &str,
        code_challenge: &str,
        consent: &str,
    ) -> String {
        url::form_urlencoded::Serializer::new(String::new())
            .append_pair("client_id", client_id)
            .append_pair("redirect_uri", CALLBACK)
            .append_pair("response_type", "code")
            .append_pair("scope", scope)
            .append_pair("state", state)
            .append_pair("code_challenge", code_challenge)
            .append_pair("code_challenge_method", "S256")
            .append_pair("consent", consent)
            .finish()
    }

    /// POSTs `form` to `/oauth/authorize`, with a session cookie if given.
    async fn post_consent(app: Router, cookie: Option<&str>, form: String) -> Response {
        let mut builder = Request::builder()
            .method("POST")
            .uri("/oauth/authorize")
            .header("content-type", "application/x-www-form-urlencoded");
        if let Some(cookie) = cookie {
            builder = builder.header("cookie", cookie);
        }
        let req = builder.body(Body::from(form)).unwrap();
        app.oneshot(req).await.unwrap()
    }

    #[tokio::test]
    async fn post_approve_creates_code_and_redirects_303() {
        let ath = test_ath().await;
        let app = post_app(ath.clone());
        let (_, cookie) = create_session(&ath, "post_approve@example.com").await;
        let application = setup_application(&ath).await;

        let form = consent_form(
            application.client_id.as_str(),
            "openid profile",
            "mystate",
            "mychallenge",
            "approve",
        );
        let resp = post_consent(app, Some(&cookie), form).await;
        assert_eq!(resp.status(), StatusCode::SEE_OTHER);
        let location = location_of(&resp);
        assert!(location.starts_with(CALLBACK));
        assert!(location.contains("code="));
        assert!(location.contains("state=mystate"));
    }

    #[tokio::test]
    async fn post_deny_redirects_with_access_denied_303() {
        let ath = test_ath().await;
        let app = post_app(ath.clone());
        let (_, cookie) = create_session(&ath, "post_deny@example.com").await;
        let application = setup_application(&ath).await;

        let form = consent_form(
            application.client_id.as_str(),
            "openid profile",
            "mystate",
            "mychallenge",
            "deny",
        );
        let resp = post_consent(app, Some(&cookie), form).await;
        assert_eq!(resp.status(), StatusCode::SEE_OTHER);
        let location = location_of(&resp);
        assert!(location.contains("error=access_denied"));
        assert!(location.contains("state=mystate"));
    }

    #[tokio::test]
    async fn post_unauthenticated_redirects_to_login() {
        let ath = test_ath().await;
        let app = post_app(ath.clone());
        let application = setup_application(&ath).await;

        let form = consent_form(application.client_id.as_str(), "openid", "s", "c", "approve");
        let resp = post_consent(app, None, form).await;
        assert_eq!(resp.status(), StatusCode::SEE_OTHER);
        assert!(location_of(&resp).starts_with("/login?next="));
    }

    #[tokio::test]
    async fn post_with_invalid_client_id_returns_400() {
        let ath = test_ath().await;
        let app = post_app(ath.clone());
        let (_, cookie) = create_session(&ath, "post_revalidate@example.com").await;

        let form = consent_form("ath_nonexistent", "openid", "s", "c", "approve");
        let resp = post_consent(app, Some(&cookie), form).await;
        assert_eq!(resp.status(), StatusCode::BAD_REQUEST);
        let body = read_body(resp).await;
        assert_eq!(body["error"], "unknown client_id");
    }
}