Skip to main content

oauth2_test_server/handlers/
token.rs

1use axum::{
2    extract::{Form, State},
3    http::HeaderMap,
4    response::IntoResponse,
5    Json,
6};
7use base64::{engine::general_purpose, Engine};
8use chrono::{Duration, Utc};
9use serde::Deserialize;
10use serde_json::json;
11use sha2::Digest;
12use std::collections::HashSet;
13
14use crate::{
15    crypto::{
16        calculate_at_hash, calculate_c_hash, generate_token_string, issue_id_token, issue_jwt,
17    },
18    error::OauthError,
19    models::Token,
20    store::AppState,
21};
22
23#[derive(Deserialize)]
24pub struct TokenRequest {
25    pub grant_type: String,
26    pub code: Option<String>,
27    pub _redirect_uri: Option<String>,
28    pub client_id: Option<String>,
29    pub _client_secret: Option<String>,
30    pub refresh_token: Option<String>,
31    pub code_verifier: Option<String>,
32    pub scope: Option<String>,
33}
34
35#[tracing::instrument(skip(state, form, _headers))]
36pub async fn token_endpoint(
37    State(state): State<AppState>,
38    _headers: HeaderMap,
39    Form(form): Form<TokenRequest>,
40) -> Result<impl IntoResponse, OauthError> {
41    match form.grant_type.as_str() {
42        "authorization_code" => handle_authorization_code(state, form).await,
43        "refresh_token" => handle_refresh_token(state, form).await,
44        "client_credentials" => handle_client_credentials(state, form).await,
45        _ => Err(OauthError::UnsupportedGrantType),
46    }
47}
48
49async fn handle_authorization_code(
50    state: AppState,
51    form: TokenRequest,
52) -> Result<Json<serde_json::Value>, OauthError> {
53    let code = form.code.as_deref().unwrap_or("");
54    let code_obj = state
55        .store
56        .remove_code(code)
57        .await
58        .ok_or(OauthError::InvalidGrant)?;
59
60    if code_obj.expires_at < Utc::now() {
61        return Err(OauthError::InvalidGrant);
62    }
63
64    if let (Some(challenge), Some(verifier)) = (&code_obj.code_challenge, &form.code_verifier) {
65        let method = code_obj.code_challenge_method.as_deref().unwrap_or("plain");
66        let computed = if method == "S256" {
67            general_purpose::URL_SAFE_NO_PAD.encode(sha2::Sha256::digest(verifier.as_bytes()))
68        } else {
69            verifier.clone()
70        };
71        if computed != *challenge {
72            return Err(OauthError::InvalidGrant);
73        }
74    }
75
76    let refresh_token = generate_token_string();
77
78    let jwt = issue_jwt(
79        state.issuer(),
80        &code_obj.client_id,
81        &code_obj.user_id,
82        &code_obj.scope,
83        state.config.access_token_expires_in as i64,
84        &state.keys,
85    )
86    .map_err(|_| OauthError::ServerError)?;
87
88    let scopes: HashSet<&str> = code_obj.scope.split_whitespace().collect();
89    let include_id_token = scopes.contains("openid");
90
91    let id_token = if include_id_token {
92        let at_hash = calculate_at_hash(&jwt);
93        let c_hash = calculate_c_hash(code);
94
95        let user_claims = json!({
96            "name": code_obj.user_id.clone(),
97        });
98
99        let id_token = issue_id_token(
100            state.issuer(),
101            &code_obj.client_id,
102            &code_obj.user_id,
103            code_obj.nonce.as_deref(),
104            Some(&at_hash),
105            Some(&c_hash),
106            state.config.access_token_expires_in as i64,
107            user_claims,
108            &state.keys,
109        )
110        .map_err(|_| OauthError::ServerError)?;
111
112        Some(id_token)
113    } else {
114        None
115    };
116
117    let token = Token {
118        access_token: jwt.clone(),
119        refresh_token: Some(refresh_token.clone()),
120        client_id: code_obj.client_id.clone(),
121        scope: code_obj.scope.clone(),
122        expires_at: Utc::now() + Duration::seconds(state.config.access_token_expires_in as i64),
123        user_id: code_obj.user_id.clone(),
124        revoked: false,
125    };
126
127    state.store.insert_token(jwt.clone(), token.clone()).await;
128    state
129        .store
130        .insert_refresh_token(refresh_token.clone(), token)
131        .await;
132
133    let mut response = json!({
134        "access_token": jwt,
135        "token_type": "Bearer",
136        "expires_in": state.config.access_token_expires_in,
137        "refresh_token": refresh_token,
138        "scope": code_obj.scope
139    });
140
141    if let Some(id) = id_token {
142        response["id_token"] = serde_json::Value::String(id);
143    }
144
145    if let Some(ref state) = code_obj.state {
146        response["state"] = serde_json::Value::String(state.clone());
147    }
148
149    Ok(Json(response))
150}
151
152async fn handle_refresh_token(
153    state: AppState,
154    form: TokenRequest,
155) -> Result<Json<serde_json::Value>, OauthError> {
156    let rt = form.refresh_token.as_deref().unwrap_or("");
157    let mut token = state
158        .store
159        .get_refresh_token(rt)
160        .await
161        .ok_or(OauthError::InvalidGrant)?;
162
163    if token.revoked {
164        return Err(OauthError::InvalidGrant);
165    }
166
167    let new_access_token = issue_jwt(
168        state.issuer(),
169        &token.client_id,
170        &token.user_id,
171        &token.scope,
172        state.config.access_token_expires_in as i64,
173        &state.keys,
174    )
175    .map_err(|_| OauthError::ServerError)?;
176
177    let new_refresh_token = generate_token_string();
178
179    let new_token = Token {
180        access_token: new_access_token.clone(),
181        refresh_token: Some(new_refresh_token.clone()),
182        client_id: token.client_id.clone(),
183        scope: token.scope.clone(),
184        expires_at: Utc::now() + Duration::seconds(state.config.access_token_expires_in as i64),
185        user_id: token.user_id.clone(),
186        revoked: false,
187    };
188
189    state
190        .store
191        .insert_token(new_access_token.clone(), new_token.clone())
192        .await;
193    state
194        .store
195        .insert_refresh_token(new_refresh_token.clone(), new_token)
196        .await;
197
198    token.revoked = true;
199    state.store.update_refresh_token(rt, token.clone()).await;
200
201    Ok(Json(json!({
202        "access_token": new_access_token,
203        "token_type": "Bearer",
204        "expires_in": state.config.access_token_expires_in,
205        "refresh_token": new_refresh_token,
206        "scope": token.scope
207    })))
208}
209
210async fn handle_client_credentials(
211    state: AppState,
212    form: TokenRequest,
213) -> Result<Json<serde_json::Value>, OauthError> {
214    let client_id = form.client_id.as_deref().unwrap_or("");
215    let client = state
216        .store
217        .get_client(client_id)
218        .await
219        .ok_or(OauthError::InvalidClient)?;
220
221    let requested_scopes: HashSet<String> = form
222        .scope
223        .as_deref()
224        .unwrap_or("")
225        .split_whitespace()
226        .map(|s| s.to_string())
227        .collect();
228
229    if let Some(requested_scope) = form.scope.as_deref() {
230        if let Err(e) = state.config.validate_scope(requested_scope) {
231            return Err(OauthError::InvalidScope(e));
232        }
233
234        let client_scopes: HashSet<_> = client.scope.split_whitespace().collect();
235        let requested_scopes_set: HashSet<_> = requested_scope.split_whitespace().collect();
236
237        let not_permitted: Vec<_> = requested_scopes_set
238            .difference(&client_scopes)
239            .cloned()
240            .collect();
241
242        if !not_permitted.is_empty() {
243            return Err(OauthError::InvalidScope(format!(
244                "Client not authorized for scopes: {}",
245                not_permitted.join(" ")
246            )));
247        }
248    }
249
250    let registered_scopes: HashSet<String> = client
251        .scope
252        .split_whitespace()
253        .map(|s| s.to_string())
254        .collect();
255
256    let granted_scopes: Vec<String> = requested_scopes
257        .intersection(&registered_scopes)
258        .cloned()
259        .collect();
260
261    if granted_scopes.is_empty() && !requested_scopes.is_empty() {
262        return Err(OauthError::InvalidScope(
263            "Requested scopes not allowed for this client".to_string(),
264        ));
265    }
266
267    let final_scope = if requested_scopes.is_empty() {
268        client.scope.clone()
269    } else {
270        granted_scopes.join(" ")
271    };
272
273    let access_token = issue_jwt(
274        state.issuer(),
275        client_id,
276        "client",
277        &final_scope,
278        state.config.access_token_expires_in as i64,
279        &state.keys,
280    )
281    .map_err(|_| OauthError::ServerError)?;
282
283    Ok(Json(json!({
284        "access_token": access_token,
285        "token_type": "Bearer",
286        "expires_in": state.config.access_token_expires_in,
287        "scope": final_scope
288    })))
289}