oauth2_test_server/handlers/
authorize.rs1use axum::{
2 extract::{Query, State},
3 response::IntoResponse,
4 response::Redirect,
5};
6use chrono::{Duration, Utc};
7use serde::Deserialize;
8use std::collections::HashSet;
9
10use crate::{
11 crypto::generate_code,
12 models::{AuthorizationCode, Token},
13 store::AppState,
14};
15
16#[derive(Deserialize, Debug)]
17pub struct AuthorizeQuery {
18 pub response_type: String,
19 pub client_id: String,
20 pub redirect_uri: Option<String>,
21 pub scope: Option<String>,
22 pub state: Option<String>,
23 pub response_mode: Option<String>,
24 pub code_challenge: Option<String>,
25 pub code_challenge_method: Option<String>,
26 pub nonce: Option<String>,
27 pub prompt: Option<String>,
28 pub max_age: Option<String>,
29 pub claims: Option<String>,
30 pub ui_locales: Option<String>,
31}
32
/// A single value of the `prompt` authorization parameter.
///
/// `Prompt::default()` yields [`Prompt::Consent`].
#[derive(Clone, Debug, Default, PartialEq)]
pub enum Prompt {
    /// Parsed from `"none"`.
    None,
    /// Parsed from `"login"`.
    Login,
    /// Parsed from `"consent"`; the `Default` variant.
    #[default]
    Consent,
    /// Parsed from `"select_account"`.
    SelectAccount,
}

// `from_str` is kept as an inherent method returning `Option` (rather than a
// `std::str::FromStr` impl, whose signature returns `Result`) so the existing
// API is preserved; hence clippy's `should_implement_trait` is allowed here.
#[allow(clippy::should_implement_trait)]
impl Prompt {
    /// Parses one prompt value, ignoring case.
    ///
    /// Returns `None` unless the lowercased input is exactly one of `none`,
    /// `login`, `consent` or `select_account` (so the empty string is `None`).
    pub fn from_str(s: &str) -> Option<Self> {
        const KNOWN: [(&str, Prompt); 4] = [
            ("none", Prompt::None),
            ("login", Prompt::Login),
            ("consent", Prompt::Consent),
            ("select_account", Prompt::SelectAccount),
        ];

        let lowered = s.to_lowercase();
        KNOWN
            .iter()
            .find(|(name, _)| *name == lowered)
            .map(|(_, variant)| variant.clone())
    }
}
54
55#[tracing::instrument(skip(state))]
60pub async fn authorize(
61 State(state): State<AppState>,
62 Query(params): Query<AuthorizeQuery>,
63) -> impl IntoResponse {
64 let client = match state.store.get_client(¶ms.client_id).await {
65 Some(c) => c,
66 None => {
67 return Redirect::to(&format!(
68 "/error?error=invalid_client&state={}",
69 params.state.as_deref().unwrap_or("")
70 ))
71 .into_response();
72 }
73 };
74
75 if state.config.require_state && params.state.is_none() {
76 return Redirect::to(
77 "/error?error=invalid_request&error_description=state_parameter_required",
78 )
79 .into_response();
80 }
81
82 let supported_response_types = [
83 "code",
84 "token",
85 "id_token",
86 "code token",
87 "code id_token",
88 "token id_token",
89 "code token id_token",
90 ];
91 if !supported_response_types.contains(¶ms.response_type.as_str()) {
92 return Redirect::to(&format!(
93 "/error?error=unsupported_response_type&state={}",
94 params.state.as_deref().unwrap_or("")
95 ))
96 .into_response();
97 }
98
99 if let Some(ref prompt) = params.prompt {
100 if let Some(p) = Prompt::from_str(prompt) {
101 match p {
102 Prompt::None => {
103 return Redirect::to(&format!(
104 "/error?error=invalid_request&error_description=prompt=none requires no existing session&state={}",
105 params.state.as_deref().unwrap_or("")
106 ))
107 .into_response();
108 }
109 Prompt::Login | Prompt::Consent | Prompt::SelectAccount => {}
110 }
111 } else {
112 return Redirect::to(&format!(
113 "/error?error=invalid_request&error_description=invalid prompt value&state={}",
114 params.state.as_deref().unwrap_or("")
115 ))
116 .into_response();
117 }
118 }
119
120 if let Some(ref max_age) = params.max_age {
121 if max_age.parse::<i64>().is_err() {
122 return Redirect::to(&format!(
123 "/error?error=invalid_request&error_description=max_age must be an integer&state={}",
124 params.state.as_deref().unwrap_or("")
125 ))
126 .into_response();
127 }
128 }
129
130 if let Some(ref claims) = params.claims {
131 if serde_json::from_str::<serde_json::Value>(claims).is_err() {
132 return Redirect::to(&format!(
133 "/error?error=invalid_request&error_description=invalid claims parameter&state={}",
134 params.state.as_deref().unwrap_or("")
135 ))
136 .into_response();
137 }
138 }
139
140 let redirect_uri = match ¶ms.redirect_uri {
141 Some(uri) => {
142 if !client.redirect_uris.contains(uri) {
143 return Redirect::to(&format!(
144 "/error?error=invalid_request&state={}",
145 params.state.as_deref().unwrap_or("")
146 ))
147 .into_response();
148 }
149 uri.clone()
150 }
151 None => match client.redirect_uris.first() {
152 Some(uri) => uri.clone(),
153 None => {
154 return Redirect::to(&format!(
155 "/error?error=invalid_request&state={}&error_description=no_redirect_uri",
156 params.state.as_deref().unwrap_or("")
157 ))
158 .into_response();
159 }
160 },
161 };
162
163 let code = generate_code();
164
165 let requested_scopes: HashSet<String> = params
166 .scope
167 .clone()
168 .unwrap_or_default()
169 .split_whitespace()
170 .map(|s| s.to_string())
171 .collect();
172 let registered_scopes: HashSet<String> = client
173 .scope
174 .split_whitespace()
175 .map(|s| s.to_string())
176 .collect();
177 let granted_scopes: Vec<String> = requested_scopes
178 .intersection(®istered_scopes)
179 .cloned()
180 .collect();
181 let final_scope = granted_scopes.join(" ");
182
183 let auth_code = AuthorizationCode {
184 code: code.clone(),
185 client_id: params.client_id.clone(),
186 redirect_uri: redirect_uri.clone(),
187 scope: final_scope,
188 expires_at: Utc::now()
189 + Duration::seconds(state.config.authorization_code_expires_in as i64),
190 code_challenge: params.code_challenge.clone(),
191 code_challenge_method: params.code_challenge_method.clone(),
192 user_id: state.config.default_user_id.clone(),
193 nonce: params.nonce.clone(),
194 state: params.state.clone(),
195 };
196
197 state.store.insert_code(code.clone(), auth_code).await;
198
199 let response_mode = params.response_mode.as_deref().unwrap_or("query");
200 let state_param = params.state.as_deref().unwrap_or("");
201
202 match response_mode {
203 "form_post" => {
204 let form_html = format!(
205 r#"<!DOCTYPE html>
206<html>
207<head><title>Redirect</title></head>
208<body>
209<form id="form" method="POST" action="{}">
210<input type="hidden" name="code" value="{}"/>
211<input type="hidden" name="state" value="{}"/>
212</form>
213<script>document.getElementById('form').submit();</script>
214</body>
215</html>"#,
216 redirect_uri, code, state_param
217 );
218 (
219 http::StatusCode::OK,
220 [("Content-Type", "text/html")],
221 form_html,
222 )
223 .into_response()
224 }
225 "fragment" => {
226 let redirect_url = format!("{}?code={}&state={}#", redirect_uri, code, state_param);
227 Redirect::to(&redirect_url).into_response()
228 }
229 _ => {
230 let redirect_url = format!("{}?code={}&state={}", redirect_uri, code, state_param);
231 Redirect::to(&redirect_url).into_response()
232 }
233 }
234}
235
236pub async fn store_token(state: &AppState, token: Token) {
238 let jwt = token.access_token.clone();
239 if let Some(rt) = token.refresh_token.clone() {
240 state.store.insert_refresh_token(rt, token.clone()).await;
241 }
242 state.store.insert_token(jwt, token).await;
243}