Skip to main content

oauth2_test_server/handlers/
authorize.rs

1use axum::{
2    extract::{Query, State},
3    response::IntoResponse,
4    response::Redirect,
5};
6use chrono::{Duration, Utc};
7use serde::Deserialize;
8use std::collections::HashSet;
9
10use crate::{
11    crypto::generate_code,
12    models::{AuthorizationCode, Token},
13    store::AppState,
14};
15
16#[derive(Deserialize, Debug)]
17pub struct AuthorizeQuery {
18    pub response_type: String,
19    pub client_id: String,
20    pub redirect_uri: Option<String>,
21    pub scope: Option<String>,
22    pub state: Option<String>,
23    pub response_mode: Option<String>,
24    pub code_challenge: Option<String>,
25    pub code_challenge_method: Option<String>,
26    pub nonce: Option<String>,
27    pub prompt: Option<String>,
28    pub max_age: Option<String>,
29    pub claims: Option<String>,
30    pub ui_locales: Option<String>,
31}
32
/// A single OpenID Connect `prompt` value.
///
/// [`Prompt::Consent`] is the [`Default`] variant.
#[derive(Debug, Clone, PartialEq, Default)]
pub enum Prompt {
    /// `none`
    None,
    /// `login`
    Login,
    /// `consent`
    #[default]
    Consent,
    /// `select_account`
    SelectAccount,
}

// `from_str` returns `Option` rather than `Result<Self, Err>`, so it does not
// match `std::str::FromStr`; the lint is silenced to keep this existing API.
#[allow(clippy::should_implement_trait)]
impl Prompt {
    /// Parses one `prompt` token, ignoring ASCII case.
    ///
    /// Returns `None` for anything other than `none`, `login`, `consent` or
    /// `select_account` (including the empty string and multi-value strings).
    pub fn from_str(s: &str) -> Option<Self> {
        [
            ("none", Prompt::None),
            ("login", Prompt::Login),
            ("consent", Prompt::Consent),
            ("select_account", Prompt::SelectAccount),
        ]
        .iter()
        .find(|(keyword, _)| keyword.eq_ignore_ascii_case(s))
        .map(|(_, prompt)| prompt.clone())
    }
}
54
55/// `GET /authorize` — OAuth2 authorization endpoint (authorization code flow).
56///
57/// In this test server, consent is auto-granted. The `default_user_id` from
58/// [`IssuerConfig`] is used as the authenticated user.
59#[tracing::instrument(skip(state))]
60pub async fn authorize(
61    State(state): State<AppState>,
62    Query(params): Query<AuthorizeQuery>,
63) -> impl IntoResponse {
64    let client = match state.store.get_client(&params.client_id).await {
65        Some(c) => c,
66        None => {
67            return Redirect::to(&format!(
68                "/error?error=invalid_client&state={}",
69                params.state.as_deref().unwrap_or("")
70            ))
71            .into_response();
72        }
73    };
74
75    if state.config.require_state && params.state.is_none() {
76        return Redirect::to(
77            "/error?error=invalid_request&error_description=state_parameter_required",
78        )
79        .into_response();
80    }
81
82    let supported_response_types = [
83        "code",
84        "token",
85        "id_token",
86        "code token",
87        "code id_token",
88        "token id_token",
89        "code token id_token",
90    ];
91    if !supported_response_types.contains(&params.response_type.as_str()) {
92        return Redirect::to(&format!(
93            "/error?error=unsupported_response_type&state={}",
94            params.state.as_deref().unwrap_or("")
95        ))
96        .into_response();
97    }
98
99    if let Some(ref prompt) = params.prompt {
100        if let Some(p) = Prompt::from_str(prompt) {
101            match p {
102                Prompt::None => {
103                    return Redirect::to(&format!(
104                        "/error?error=invalid_request&error_description=prompt=none requires no existing session&state={}",
105                        params.state.as_deref().unwrap_or("")
106                    ))
107                    .into_response();
108                }
109                Prompt::Login | Prompt::Consent | Prompt::SelectAccount => {}
110            }
111        } else {
112            return Redirect::to(&format!(
113                "/error?error=invalid_request&error_description=invalid prompt value&state={}",
114                params.state.as_deref().unwrap_or("")
115            ))
116            .into_response();
117        }
118    }
119
120    if let Some(ref max_age) = params.max_age {
121        if max_age.parse::<i64>().is_err() {
122            return Redirect::to(&format!(
123                "/error?error=invalid_request&error_description=max_age must be an integer&state={}",
124                params.state.as_deref().unwrap_or("")
125            ))
126            .into_response();
127        }
128    }
129
130    if let Some(ref claims) = params.claims {
131        if serde_json::from_str::<serde_json::Value>(claims).is_err() {
132            return Redirect::to(&format!(
133                "/error?error=invalid_request&error_description=invalid claims parameter&state={}",
134                params.state.as_deref().unwrap_or("")
135            ))
136            .into_response();
137        }
138    }
139
140    let redirect_uri = match &params.redirect_uri {
141        Some(uri) => {
142            if !client.redirect_uris.contains(uri) {
143                return Redirect::to(&format!(
144                    "/error?error=invalid_request&state={}",
145                    params.state.as_deref().unwrap_or("")
146                ))
147                .into_response();
148            }
149            uri.clone()
150        }
151        None => match client.redirect_uris.first() {
152            Some(uri) => uri.clone(),
153            None => {
154                return Redirect::to(&format!(
155                    "/error?error=invalid_request&state={}&error_description=no_redirect_uri",
156                    params.state.as_deref().unwrap_or("")
157                ))
158                .into_response();
159            }
160        },
161    };
162
163    let code = generate_code();
164
165    let requested_scopes: HashSet<String> = params
166        .scope
167        .clone()
168        .unwrap_or_default()
169        .split_whitespace()
170        .map(|s| s.to_string())
171        .collect();
172    let registered_scopes: HashSet<String> = client
173        .scope
174        .split_whitespace()
175        .map(|s| s.to_string())
176        .collect();
177    let granted_scopes: Vec<String> = requested_scopes
178        .intersection(&registered_scopes)
179        .cloned()
180        .collect();
181    let final_scope = granted_scopes.join(" ");
182
183    let auth_code = AuthorizationCode {
184        code: code.clone(),
185        client_id: params.client_id.clone(),
186        redirect_uri: redirect_uri.clone(),
187        scope: final_scope,
188        expires_at: Utc::now()
189            + Duration::seconds(state.config.authorization_code_expires_in as i64),
190        code_challenge: params.code_challenge.clone(),
191        code_challenge_method: params.code_challenge_method.clone(),
192        user_id: state.config.default_user_id.clone(),
193        nonce: params.nonce.clone(),
194        state: params.state.clone(),
195    };
196
197    state.store.insert_code(code.clone(), auth_code).await;
198
199    let response_mode = params.response_mode.as_deref().unwrap_or("query");
200    let state_param = params.state.as_deref().unwrap_or("");
201
202    match response_mode {
203        "form_post" => {
204            let form_html = format!(
205                r#"<!DOCTYPE html>
206<html>
207<head><title>Redirect</title></head>
208<body>
209<form id="form" method="POST" action="{}">
210<input type="hidden" name="code" value="{}"/>
211<input type="hidden" name="state" value="{}"/>
212</form>
213<script>document.getElementById('form').submit();</script>
214</body>
215</html>"#,
216                redirect_uri, code, state_param
217            );
218            (
219                http::StatusCode::OK,
220                [("Content-Type", "text/html")],
221                form_html,
222            )
223                .into_response()
224        }
225        "fragment" => {
226            let redirect_url = format!("{}?code={}&state={}#", redirect_uri, code, state_param);
227            Redirect::to(&redirect_url).into_response()
228        }
229        _ => {
230            let redirect_url = format!("{}?code={}&state={}", redirect_uri, code, state_param);
231            Redirect::to(&redirect_url).into_response()
232        }
233    }
234}
235
236/// Helper used by the testkit to store a pre-built `Token` directly.
237pub async fn store_token(state: &AppState, token: Token) {
238    let jwt = token.access_token.clone();
239    if let Some(rt) = token.refresh_token.clone() {
240        state.store.insert_refresh_token(rt, token.clone()).await;
241    }
242    state.store.insert_token(jwt, token).await;
243}