1use base64::{engine::general_purpose, Engine};
2use jsonwebtoken::{encode, Algorithm, Header};
3use rand::Rng;
4use reqwest::{Client as HttpClient, StatusCode};
5use serde_json::{json, Value};
6use sha2::{Digest, Sha256};
7use tokio::task::JoinError;
8use url::Url;
9use uuid::Uuid;
10
/// A PKCE (RFC 7636) verifier/challenge pair, as produced by `OAuthTestServer::pkce_pair`.
///
/// The struct does not enforce any relationship between the two fields; callers are
/// expected to supply a challenge derived from the verifier, because the authorize URL
/// builder always advertises `code_challenge_method=S256` for it.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PkcePair {
    /// Secret sent to the token endpoint as `code_verifier`.
    pub code_verifier: String,
    /// Value sent to the authorize endpoint as `code_challenge`.
    pub code_challenge: String,
}
16
17#[derive(Debug, Default)]
18pub struct AuthorizeParams {
19 pub response_type: &'static str,
20 pub redirect_uri: String,
21 pub scope: String,
22 pub state: Option<String>,
23 pub response_mode: Option<String>,
24 pub pkce: Option<PkcePair>,
25 pub nonce: Option<String>,
26 pub prompt: Option<String>,
27 pub max_age: Option<String>,
28 pub claims: Option<String>,
29 pub ui_locales: Option<String>,
30}
31
32impl AuthorizeParams {
33 pub fn new() -> Self {
34 Self {
35 response_type: "code",
36 redirect_uri: "http://localhost/cb".to_string(),
37 scope: "openid".to_string(),
38 state: Some(Uuid::new_v4().to_string()),
39 response_mode: None,
40 pkce: None,
41 nonce: None,
42 prompt: None,
43 max_age: None,
44 claims: None,
45 ui_locales: None,
46 }
47 }
48
49 pub fn redirect_uri(mut self, uri: impl Into<String>) -> Self {
50 self.redirect_uri = uri.into();
51 self
52 }
53
54 pub fn scope(mut self, scope: impl Into<String>) -> Self {
55 self.scope = scope.into();
56 self
57 }
58
59 pub fn state(mut self, state: impl Into<String>) -> Self {
60 self.state = Some(state.into());
61 self
62 }
63
64 pub fn no_state(mut self) -> Self {
65 self.state = None;
66 self
67 }
68
69 pub fn response_mode(mut self, mode: impl Into<String>) -> Self {
70 self.response_mode = Some(mode.into());
71 self
72 }
73
74 pub fn pkce(mut self, pkce: PkcePair) -> Self {
75 self.pkce = Some(pkce);
76 self
77 }
78
79 pub fn nonce(mut self, nonce: impl Into<String>) -> Self {
80 self.nonce = Some(nonce.into());
81 self
82 }
83
84 pub fn prompt(mut self, prompt: impl Into<String>) -> Self {
85 self.prompt = Some(prompt.into());
86 self
87 }
88
89 pub fn max_age(mut self, max_age: impl Into<String>) -> Self {
90 self.max_age = Some(max_age.into());
91 self
92 }
93
94 pub fn claims(mut self, claims: impl Into<String>) -> Self {
95 self.claims = Some(claims.into());
96 self
97 }
98
99 pub fn ui_locales(mut self, locales: impl Into<String>) -> Self {
100 self.ui_locales = Some(locales.into());
101 self
102 }
103}
104
/// Absolute URLs of the test issuer's endpoints, computed once when the server starts.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct OauthEndpoints {
    /// Server root URL, including the trailing `/`.
    pub oauth_server: String,
    /// OpenID Connect discovery document (`.well-known/openid-configuration`).
    pub discovery: String,
    /// Authorization endpoint (`authorize`).
    pub authorize: String,
    /// Token endpoint (`token`).
    pub token: String,
    /// Client registration endpoint (`register`).
    pub register: String,
    /// Token introspection endpoint (`introspect`).
    pub introspect: String,
    /// Token revocation endpoint (`revoke`).
    pub revoke: String,
    /// UserInfo endpoint (`userinfo`).
    pub userinfo: String,
    /// JSON Web Key Set (`.well-known/jwks.json`).
    pub jwks: String,
    /// Device authorization endpoint (`device/code`).
    pub device_code: String,
    /// Device-flow token endpoint (`device/token`).
    pub device_token: String,
}
119
120pub struct OAuthTestServer {
142 state: AppState,
143 pub base_url: url::Url,
144 pub endpoints: OauthEndpoints,
145 pub http: HttpClient,
146 _handle: tokio::task::JoinHandle<()>,
147}
148
149impl OAuthTestServer {
150 pub async fn start() -> Self {
151 let config = IssuerConfig {
152 port: 0,
153 ..Default::default()
154 };
155 Self::start_with_config(config).await
156 }
157
158 pub async fn start_with_config(config: IssuerConfig) -> Self {
159 let mut state = AppState::new(config.clone());
161 let (addr, handle) = state.clone().start().await;
162 let base_url: Url = format!("http://{addr}").parse().unwrap();
163 state.base_url = base_url.to_string().trim_end_matches("/").to_string();
164 let endpoints: OauthEndpoints = OauthEndpoints {
165 oauth_server: base_url.clone().to_string(),
166 discovery: format!("{base_url}.well-known/openid-configuration"),
167 register: format!("{base_url}register"),
168 authorize: format!("{base_url}authorize"),
169 token: format!("{base_url}token"),
170 introspect: format!("{base_url}introspect"),
171 revoke: format!("{base_url}revoke"),
172 userinfo: format!("{base_url}userinfo"),
173 jwks: format!("{base_url}.well-known/jwks.json"),
174 device_code: format!("{base_url}device/code"),
175 device_token: format!("{base_url}device/token"),
176 };
177
178 Self {
179 state,
180 base_url,
181 endpoints,
182 http: HttpClient::builder()
183 .redirect(reqwest::redirect::Policy::none())
184 .build()
185 .unwrap(),
186 _handle: handle,
187 }
188 }
189
190 pub async fn wait_for_shutdown(self) -> Result<(), JoinError> {
191 self._handle.await
192 }
193
194 pub async fn register_client(&self, metadata: serde_json::Value) -> Client {
195 self.state
196 .register_client(metadata)
197 .await
198 .expect("client registration failed")
199 }
200
201 pub async fn register_client_with_secret(&self, metadata: Value, force_secret: bool) -> Client {
202 let mut meta = metadata;
203 if let Some(obj) = meta.as_object_mut() {
204 obj.insert(
205 "generate_client_secret_for_dcr".to_string(),
206 json!(force_secret),
207 );
208 }
209 self.register_client(meta).await
210 }
211
212 pub fn generate_jwt(&self, client: &Client, options: JwtOptions) -> String {
213 self.state
214 .generate_jwt(client, options)
215 .expect("JWT generation failed")
216 }
217
218 pub async fn generate_token(&self, client: &Client, options: JwtOptions) -> Token {
219 self.state
220 .generate_token(client, options)
221 .await
222 .expect("Token generation failed")
223 }
224
225 pub async fn clients(&self) -> Vec<Client> {
226 self.state.store.get_all_clients().await
227 }
228
229 pub async fn codes(&self) -> Vec<AuthorizationCode> {
230 self.state.store.get_all_codes().await
231 }
232
233 pub async fn tokens(&self) -> Vec<Token> {
234 self.state.store.get_all_tokens().await
235 }
236
237 pub async fn refresh_tokens(&self) -> Vec<Token> {
238 self.state.store.get_all_refresh_tokens().await
239 }
240
241 pub async fn clear_clients(&self) {
242 self.state.store.clear_clients().await;
243 }
244
245 pub async fn clear_codes(&self) {
246 self.state.store.clear_codes().await;
247 }
248
249 pub async fn clear_tokens(&self) {
250 self.state.store.clear_tokens().await;
251 }
252
253 pub async fn clear_refresh_tokens(&self) {
254 self.state.store.clear_refresh_tokens().await;
255 }
256
257 pub async fn clear_device_codes(&self) {
258 self.state.store.clear_device_codes().await;
259 }
260
261 pub async fn clear_all(&self) {
262 self.state.store.clear_all().await;
263 }
264
265 pub async fn approve_device_code(&self, device_code: &str, user_id: &str) {
266 self.state
267 .approve_device_code(device_code, user_id)
268 .await
269 .expect("device code not found");
270 }
271
272 pub fn state(&self) -> &AppState {
273 &self.state
274 }
275
276 pub fn jwt_options(&self) -> JwtOptionsBuilder {
277 JwtOptionsBuilder::default()
278 }
279
280 pub fn pkce_pair(&self) -> PkcePair {
281 let verifier_bytes: [u8; 32] = rand::thread_rng().r#gen();
282 let code_verifier = general_purpose::URL_SAFE_NO_PAD.encode(verifier_bytes);
283 let challenge =
284 general_purpose::URL_SAFE_NO_PAD.encode(Sha256::digest(code_verifier.as_bytes()));
285 PkcePair {
286 code_verifier,
287 code_challenge: challenge,
288 }
289 }
290
291 pub fn authorize_url(&self, client: &Client, params: AuthorizeParams) -> Url {
292 let mut url = self.base_url.join("authorize").unwrap();
293 let mut query = url.query_pairs_mut();
294
295 query
296 .append_pair("response_type", params.response_type)
297 .append_pair("client_id", &client.client_id)
298 .append_pair("redirect_uri", ¶ms.redirect_uri)
299 .append_pair("scope", ¶ms.scope);
300
301 if let Some(state) = params.state {
302 query.append_pair("state", &state);
303 }
304
305 if let Some(ref response_mode) = params.response_mode {
306 query.append_pair("response_mode", response_mode);
307 }
308
309 if let Some(pkce) = params.pkce {
310 query
311 .append_pair("code_challenge", &pkce.code_challenge)
312 .append_pair("code_challenge_method", "S256");
313 }
314
315 if let Some(nonce) = params.nonce {
316 query.append_pair("nonce", &nonce);
317 }
318
319 if let Some(ref prompt) = params.prompt {
320 query.append_pair("prompt", prompt);
321 }
322
323 if let Some(ref max_age) = params.max_age {
324 query.append_pair("max_age", max_age);
325 }
326
327 if let Some(ref claims) = params.claims {
328 query.append_pair("claims", claims);
329 }
330
331 if let Some(ref ui_locales) = params.ui_locales {
332 query.append_pair("ui_locales", ui_locales);
333 }
334
335 drop(query);
336 url
337 }
338
339 pub fn rotate_keys(&self) {
340 unimplemented!("Key rotation not implemented in test server")
342 }
343
344 pub async fn approve_consent(&self, auth_url: &Url, user_id: &str) -> String {
345 let resp = self.http.get(auth_url.clone()).send().await.unwrap();
346 assert_eq!(resp.status(), StatusCode::SEE_OTHER);
347
348 let location = resp.headers().get("location").unwrap().to_str().unwrap();
349 let redirect = Url::parse(location).unwrap();
350 let code = redirect
351 .query_pairs()
352 .find(|(k, _)| k == "code")
353 .map(|(_, v)| v.to_string())
354 .expect("no code in redirect");
355
356 let code_obj = self
358 .state
359 .store
360 .get_code(&code)
361 .await
362 .expect("code not found");
363 let mut code_obj = code_obj.clone();
364 code_obj.user_id = user_id.to_string();
365 self.state.store.insert_code(code.clone(), code_obj).await;
366
367 code
368 }
369
370 pub async fn exchange_code(
371 &self,
372 client: &Client,
373 code: &str,
374 pkce: Option<&PkcePair>,
375 ) -> Value {
376 let mut form = vec![
377 ("grant_type", "authorization_code"),
378 ("code", code),
379 ("redirect_uri", "http://localhost/cb"),
380 ];
381
382 if let Some(pkce) = pkce {
383 form.push(("code_verifier", &pkce.code_verifier));
384 }
385
386 let resp = self
387 .http
388 .post(self.base_url.join("token").unwrap())
389 .basic_auth(&client.client_id, client.client_secret.as_ref())
390 .form(&form)
391 .send()
392 .await
393 .unwrap();
394
395 assert_eq!(resp.status(), StatusCode::OK);
396 resp.json().await.unwrap()
397 }
398
399 pub async fn refresh_token(&self, client: &Client, refresh_token: &str) -> Value {
400 let resp = self
401 .http
402 .post(self.base_url.join("token").unwrap())
403 .basic_auth(&client.client_id, client.client_secret.as_ref())
404 .form(&[
405 ("grant_type", "refresh_token"),
406 ("refresh_token", refresh_token),
407 ])
408 .send()
409 .await
410 .unwrap();
411
412 resp.json().await.unwrap()
413 }
414
415 pub async fn introspect_token(&self, client: &Client, token: &str) -> Value {
416 let resp = self
417 .http
418 .post(self.base_url.join("introspect").unwrap())
419 .basic_auth(&client.client_id, client.client_secret.as_ref())
420 .form(&[("token", token)])
421 .send()
422 .await
423 .unwrap();
424
425 resp.json().await.unwrap()
426 }
427
428 pub async fn revoke_token(&self, client: &Client, token: &str) {
429 let resp = self
430 .http
431 .post(self.base_url.join("revoke").unwrap())
432 .basic_auth(&client.client_id, client.client_secret.as_ref())
433 .form(&[("token", token)])
434 .send()
435 .await
436 .unwrap();
437
438 assert!(resp.status().is_success());
439 }
440
441 pub fn client_assertion_jwt(&self, client: &Client) -> String {
442 let claims = json!({
443 "iss": client.client_id,
444 "sub": client.client_id,
445 "aud": self.issuer(),
446 "exp": (chrono::Utc::now() + chrono::Duration::minutes(5)).timestamp(),
447 "iat": chrono::Utc::now().timestamp(),
448 "jti": Uuid::new_v4().to_string(),
449 });
450
451 let mut header = Header::new(Algorithm::RS256);
452 header.kid = Some(self.state.keys.kid.clone());
453
454 encode(&header, &claims, &self.state.keys.encoding).unwrap()
455 }
456
457 pub fn base_url(&self) -> &url::Url {
458 &self.base_url
459 }
460
461 pub fn issuer(&self) -> &str {
462 self.state.issuer()
463 }
464
465 pub async fn complete_auth_flow(
468 &self,
469 client: &Client,
470 params: AuthorizeParams,
471 user_id: &str,
472 ) -> Value {
473 let pkce = params.pkce.clone();
474
475 let auth_url = self.authorize_url(client, params);
476 let code = self.approve_consent(&auth_url, user_id).await;
477
478 self.exchange_code(client, &code, pkce.as_ref()).await
479 }
480
481 pub async fn complete_device_flow(&self, client: &Client, scope: &str, user_id: &str) -> Value {
484 let scope_str = scope.to_string();
485 let device_resp = self
486 .http
487 .post(self.base_url.join("device/code").unwrap())
488 .form(&[("client_id", &client.client_id), ("scope", &scope_str)])
489 .send()
490 .await
491 .unwrap();
492
493 let device_data: Value = device_resp.json().await.unwrap();
494 let device_code = device_data["device_code"].as_str().unwrap();
495
496 self.approve_device_code(device_code, user_id).await;
497
498 let token_resp = self
499 .http
500 .post(self.base_url.join("device/token").unwrap())
501 .form(&[
502 ("grant_type", "urn:ietf:params:oauth:grant-type:device_code"),
503 ("device_code", device_code),
504 ("client_id", &client.client_id),
505 ])
506 .send()
507 .await
508 .unwrap();
509
510 token_resp.json().await.unwrap()
511 }
512
513 pub async fn client_credentials_grant(&self, client: &Client, scope: Option<&str>) -> Value {
515 let mut form = vec![("grant_type", "client_credentials")];
516
517 if let Some(s) = scope {
518 form.push(("scope", s));
519 }
520
521 let resp = self
522 .http
523 .post(self.base_url.join("token").unwrap())
524 .basic_auth(&client.client_id, client.client_secret.as_ref())
525 .form(&form)
526 .send()
527 .await
528 .unwrap();
529
530 assert_eq!(resp.status(), StatusCode::OK);
531 resp.json().await.unwrap()
532 }
533}
534
/// Options passed to `OAuthTestServer::generate_jwt` / `generate_token`.
///
/// Build these with [`JwtOptionsBuilder`] (or `OAuthTestServer::jwt_options`). The
/// builder defaults to `user_id = "test-user-123"` and `expires_in = 3600`. The derived
/// [`Default`] instead yields an empty `user_id` and `expires_in = 0`.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct JwtOptions {
    /// User identifier handed to the token generator.
    pub user_id: String,
    /// Optional scope string handed to the token generator.
    pub scope: Option<String>,
    /// Token lifetime in seconds.
    pub expires_in: i64,
}

/// Fluent builder for [`JwtOptions`]. Unset fields fall back to defaults in
/// [`build`](Self::build).
#[derive(Debug, Clone, Default)]
pub struct JwtOptionsBuilder {
    user_id: Option<String>,
    scope: Option<String>,
    expires_in: Option<i64>,
}

impl JwtOptionsBuilder {
    /// `user_id` used when none is set.
    const DEFAULT_USER_ID: &'static str = "test-user-123";
    /// Lifetime in seconds used when none is set.
    const DEFAULT_EXPIRES_IN: i64 = 3600;

    /// Sets the user identifier.
    #[must_use]
    pub fn user_id(mut self, user_id: impl Into<String>) -> Self {
        self.user_id = Some(user_id.into());
        self
    }

    /// Sets the scope string.
    #[must_use]
    pub fn scope(mut self, scope: impl Into<String>) -> Self {
        self.scope = Some(scope.into());
        self
    }

    /// Sets the token lifetime in seconds.
    #[must_use]
    pub fn expires_in(mut self, seconds: i64) -> Self {
        self.expires_in = Some(seconds);
        self
    }

    /// Produces the options.
    ///
    /// An unset `user_id` becomes `"test-user-123"`, an unset `expires_in` becomes
    /// `3600`, and an unset `scope` stays `None`.
    #[must_use]
    pub fn build(self) -> JwtOptions {
        JwtOptions {
            // Lazily allocate the default only when no user id was provided.
            user_id: self
                .user_id
                .unwrap_or_else(|| Self::DEFAULT_USER_ID.to_owned()),
            scope: self.scope,
            expires_in: self.expires_in.unwrap_or(Self::DEFAULT_EXPIRES_IN),
        }
    }
}
573
574use crate::config::IssuerConfig;
575use crate::models::{AuthorizationCode, Client, Token};
576use crate::store::AppState;