// oauth2_test_server/handlers/token.rs
use axum::{
    extract::{Form, State},
    http::{header::AUTHORIZATION, HeaderMap},
    response::IntoResponse,
    Json,
};
use base64::{engine::general_purpose, Engine};
use chrono::{Duration, Utc};
use serde::Deserialize;
use serde_json::json;
use sha2::Digest;
use std::collections::HashSet;

use crate::{
    crypto::{
        calculate_at_hash, calculate_c_hash, generate_token_string, issue_id_token, issue_jwt,
    },
    error::OauthError,
    models::Token,
    store::AppState,
};
22
23#[derive(Deserialize)]
24pub struct TokenRequest {
25 pub grant_type: String,
26 pub code: Option<String>,
27 pub _redirect_uri: Option<String>,
28 pub client_id: Option<String>,
29 pub _client_secret: Option<String>,
30 pub refresh_token: Option<String>,
31 pub code_verifier: Option<String>,
32 pub scope: Option<String>,
33}
34
35#[tracing::instrument(skip(state, form, _headers))]
36pub async fn token_endpoint(
37 State(state): State<AppState>,
38 _headers: HeaderMap,
39 Form(form): Form<TokenRequest>,
40) -> Result<impl IntoResponse, OauthError> {
41 match form.grant_type.as_str() {
42 "authorization_code" => handle_authorization_code(state, form).await,
43 "refresh_token" => handle_refresh_token(state, form).await,
44 "client_credentials" => handle_client_credentials(state, form).await,
45 _ => Err(OauthError::UnsupportedGrantType),
46 }
47}
48
49async fn handle_authorization_code(
50 state: AppState,
51 form: TokenRequest,
52) -> Result<Json<serde_json::Value>, OauthError> {
53 let code = form.code.as_deref().unwrap_or("");
54 let code_obj = state
55 .store
56 .remove_code(code)
57 .await
58 .ok_or(OauthError::InvalidGrant)?;
59
60 if code_obj.expires_at < Utc::now() {
61 return Err(OauthError::InvalidGrant);
62 }
63
64 if let (Some(challenge), Some(verifier)) = (&code_obj.code_challenge, &form.code_verifier) {
65 let method = code_obj.code_challenge_method.as_deref().unwrap_or("plain");
66 let computed = if method == "S256" {
67 general_purpose::URL_SAFE_NO_PAD.encode(sha2::Sha256::digest(verifier.as_bytes()))
68 } else {
69 verifier.clone()
70 };
71 if computed != *challenge {
72 return Err(OauthError::InvalidGrant);
73 }
74 }
75
76 let refresh_token = generate_token_string();
77
78 let jwt = issue_jwt(
79 state.issuer(),
80 &code_obj.client_id,
81 &code_obj.user_id,
82 &code_obj.scope,
83 state.config.access_token_expires_in as i64,
84 &state.keys,
85 )
86 .map_err(|_| OauthError::ServerError)?;
87
88 let scopes: HashSet<&str> = code_obj.scope.split_whitespace().collect();
89 let include_id_token = scopes.contains("openid");
90
91 let id_token = if include_id_token {
92 let at_hash = calculate_at_hash(&jwt);
93 let c_hash = calculate_c_hash(code);
94
95 let user_claims = json!({
96 "name": code_obj.user_id.clone(),
97 });
98
99 let id_token = issue_id_token(
100 state.issuer(),
101 &code_obj.client_id,
102 &code_obj.user_id,
103 code_obj.nonce.as_deref(),
104 Some(&at_hash),
105 Some(&c_hash),
106 state.config.access_token_expires_in as i64,
107 user_claims,
108 &state.keys,
109 )
110 .map_err(|_| OauthError::ServerError)?;
111
112 Some(id_token)
113 } else {
114 None
115 };
116
117 let token = Token {
118 access_token: jwt.clone(),
119 refresh_token: Some(refresh_token.clone()),
120 client_id: code_obj.client_id.clone(),
121 scope: code_obj.scope.clone(),
122 expires_at: Utc::now() + Duration::seconds(state.config.access_token_expires_in as i64),
123 user_id: code_obj.user_id.clone(),
124 revoked: false,
125 };
126
127 state.store.insert_token(jwt.clone(), token.clone()).await;
128 state
129 .store
130 .insert_refresh_token(refresh_token.clone(), token)
131 .await;
132
133 let mut response = json!({
134 "access_token": jwt,
135 "token_type": "Bearer",
136 "expires_in": state.config.access_token_expires_in,
137 "refresh_token": refresh_token,
138 "scope": code_obj.scope
139 });
140
141 if let Some(id) = id_token {
142 response["id_token"] = serde_json::Value::String(id);
143 }
144
145 if let Some(ref state) = code_obj.state {
146 response["state"] = serde_json::Value::String(state.clone());
147 }
148
149 Ok(Json(response))
150}
151
152async fn handle_refresh_token(
153 state: AppState,
154 form: TokenRequest,
155) -> Result<Json<serde_json::Value>, OauthError> {
156 let rt = form.refresh_token.as_deref().unwrap_or("");
157 let mut token = state
158 .store
159 .get_refresh_token(rt)
160 .await
161 .ok_or(OauthError::InvalidGrant)?;
162
163 if token.revoked {
164 return Err(OauthError::InvalidGrant);
165 }
166
167 let new_access_token = issue_jwt(
168 state.issuer(),
169 &token.client_id,
170 &token.user_id,
171 &token.scope,
172 state.config.access_token_expires_in as i64,
173 &state.keys,
174 )
175 .map_err(|_| OauthError::ServerError)?;
176
177 let new_refresh_token = generate_token_string();
178
179 let new_token = Token {
180 access_token: new_access_token.clone(),
181 refresh_token: Some(new_refresh_token.clone()),
182 client_id: token.client_id.clone(),
183 scope: token.scope.clone(),
184 expires_at: Utc::now() + Duration::seconds(state.config.access_token_expires_in as i64),
185 user_id: token.user_id.clone(),
186 revoked: false,
187 };
188
189 state
190 .store
191 .insert_token(new_access_token.clone(), new_token.clone())
192 .await;
193 state
194 .store
195 .insert_refresh_token(new_refresh_token.clone(), new_token)
196 .await;
197
198 token.revoked = true;
199 state.store.update_refresh_token(rt, token.clone()).await;
200
201 Ok(Json(json!({
202 "access_token": new_access_token,
203 "token_type": "Bearer",
204 "expires_in": state.config.access_token_expires_in,
205 "refresh_token": new_refresh_token,
206 "scope": token.scope
207 })))
208}
209
210async fn handle_client_credentials(
211 state: AppState,
212 form: TokenRequest,
213) -> Result<Json<serde_json::Value>, OauthError> {
214 let client_id = form.client_id.as_deref().unwrap_or("");
215 let client = state
216 .store
217 .get_client(client_id)
218 .await
219 .ok_or(OauthError::InvalidClient)?;
220
221 let requested_scopes: HashSet<String> = form
222 .scope
223 .as_deref()
224 .unwrap_or("")
225 .split_whitespace()
226 .map(|s| s.to_string())
227 .collect();
228
229 if let Some(requested_scope) = form.scope.as_deref() {
230 if let Err(e) = state.config.validate_scope(requested_scope) {
231 return Err(OauthError::InvalidScope(e));
232 }
233
234 let client_scopes: HashSet<_> = client.scope.split_whitespace().collect();
235 let requested_scopes_set: HashSet<_> = requested_scope.split_whitespace().collect();
236
237 let not_permitted: Vec<_> = requested_scopes_set
238 .difference(&client_scopes)
239 .cloned()
240 .collect();
241
242 if !not_permitted.is_empty() {
243 return Err(OauthError::InvalidScope(format!(
244 "Client not authorized for scopes: {}",
245 not_permitted.join(" ")
246 )));
247 }
248 }
249
250 let registered_scopes: HashSet<String> = client
251 .scope
252 .split_whitespace()
253 .map(|s| s.to_string())
254 .collect();
255
256 let granted_scopes: Vec<String> = requested_scopes
257 .intersection(®istered_scopes)
258 .cloned()
259 .collect();
260
261 if granted_scopes.is_empty() && !requested_scopes.is_empty() {
262 return Err(OauthError::InvalidScope(
263 "Requested scopes not allowed for this client".to_string(),
264 ));
265 }
266
267 let final_scope = if requested_scopes.is_empty() {
268 client.scope.clone()
269 } else {
270 granted_scopes.join(" ")
271 };
272
273 let access_token = issue_jwt(
274 state.issuer(),
275 client_id,
276 "client",
277 &final_scope,
278 state.config.access_token_expires_in as i64,
279 &state.keys,
280 )
281 .map_err(|_| OauthError::ServerError)?;
282
283 Ok(Json(json!({
284 "access_token": access_token,
285 "token_type": "Bearer",
286 "expires_in": state.config.access_token_expires_in,
287 "scope": final_scope
288 })))
289}