Skip to main content

oauth2_test_server/
testkit.rs

1use base64::{engine::general_purpose, Engine};
2use jsonwebtoken::{encode, Algorithm, Header};
3use rand::Rng;
4use reqwest::{Client as HttpClient, StatusCode};
5use serde_json::{json, Value};
6use sha2::{Digest, Sha256};
7use tokio::task::JoinError;
8use url::Url;
9use uuid::Uuid;
10
/// A PKCE (RFC 7636) verifier/challenge pair.
///
/// [`OAuthTestServer::pkce_pair`] produces pairs whose `code_challenge` is the
/// S256 transform of `code_verifier`, and `authorize_url` always advertises
/// `code_challenge_method=S256`.
#[derive(Clone, Debug)]
pub struct PkcePair {
    /// Secret sent to the token endpoint as `code_verifier`.
    pub code_verifier: String,
    /// Value sent to the authorize endpoint as `code_challenge`.
    pub code_challenge: String,
}
16
17#[derive(Debug, Default)]
18pub struct AuthorizeParams {
19    pub response_type: &'static str,
20    pub redirect_uri: String,
21    pub scope: String,
22    pub state: Option<String>,
23    pub response_mode: Option<String>,
24    pub pkce: Option<PkcePair>,
25    pub nonce: Option<String>,
26    pub prompt: Option<String>,
27    pub max_age: Option<String>,
28    pub claims: Option<String>,
29    pub ui_locales: Option<String>,
30}
31
32impl AuthorizeParams {
33    pub fn new() -> Self {
34        Self {
35            response_type: "code",
36            redirect_uri: "http://localhost/cb".to_string(),
37            scope: "openid".to_string(),
38            state: Some(Uuid::new_v4().to_string()),
39            response_mode: None,
40            pkce: None,
41            nonce: None,
42            prompt: None,
43            max_age: None,
44            claims: None,
45            ui_locales: None,
46        }
47    }
48
49    pub fn redirect_uri(mut self, uri: impl Into<String>) -> Self {
50        self.redirect_uri = uri.into();
51        self
52    }
53
54    pub fn scope(mut self, scope: impl Into<String>) -> Self {
55        self.scope = scope.into();
56        self
57    }
58
59    pub fn state(mut self, state: impl Into<String>) -> Self {
60        self.state = Some(state.into());
61        self
62    }
63
64    pub fn no_state(mut self) -> Self {
65        self.state = None;
66        self
67    }
68
69    pub fn response_mode(mut self, mode: impl Into<String>) -> Self {
70        self.response_mode = Some(mode.into());
71        self
72    }
73
74    pub fn pkce(mut self, pkce: PkcePair) -> Self {
75        self.pkce = Some(pkce);
76        self
77    }
78
79    pub fn nonce(mut self, nonce: impl Into<String>) -> Self {
80        self.nonce = Some(nonce.into());
81        self
82    }
83
84    pub fn prompt(mut self, prompt: impl Into<String>) -> Self {
85        self.prompt = Some(prompt.into());
86        self
87    }
88
89    pub fn max_age(mut self, max_age: impl Into<String>) -> Self {
90        self.max_age = Some(max_age.into());
91        self
92    }
93
94    pub fn claims(mut self, claims: impl Into<String>) -> Self {
95        self.claims = Some(claims.into());
96        self
97    }
98
99    pub fn ui_locales(mut self, locales: impl Into<String>) -> Self {
100        self.ui_locales = Some(locales.into());
101        self
102    }
103}
104
/// Absolute URLs of every endpoint exposed by a running [`OAuthTestServer`],
/// as populated by [`OAuthTestServer::start_with_config`].
#[derive(Clone, Debug)]
pub struct OauthEndpoints {
    /// Server root URL (includes the trailing `/`).
    pub oauth_server: String,
    /// `/.well-known/openid-configuration`.
    pub discovery: String,
    /// `/authorize`.
    pub authorize: String,
    /// `/token`.
    pub token: String,
    /// `/register`.
    pub register: String,
    /// `/introspect`.
    pub introspect: String,
    /// `/revoke`.
    pub revoke: String,
    /// `/userinfo`.
    pub userinfo: String,
    /// `/.well-known/jwks.json`.
    pub jwks: String,
    /// `/device/code`.
    pub device_code: String,
    /// `/device/token`.
    pub device_token: String,
}
119
120/// Start a test server with full programmatic control.
121///
122/// ```
123/// use oauth2_test_server::OAuthTestServer;
124///
125/// #[tokio::test]
126/// async fn test() {
127/// let server = OAuthTestServer::start().await;
128/// println!("server: {}", server.base_url());
129/// println!("authorize endpoint: {}", server.endpoints.authorize_url);
130/// // register a client
131/// let client = server.register_client(
132///     serde_json::json!({ "scope": "openid", "redirect_uris":["http://localhost:8080/callback"]}),
133/// );
134/// // generate a jwt
135/// let jwt = server.generate_jwt(&client, server.jwt_options().user_id("bob").build());
136/// assert_eq!(jwt.split('.').count(), 3);
137/// assert_eq!(server.clients().read().iter().len(), 1);
138/// assert_eq!(server.tokens().read().iter().len(), 1);
139/// }
140/// ```
141pub struct OAuthTestServer {
142    state: AppState,
143    pub base_url: url::Url,
144    pub endpoints: OauthEndpoints,
145    pub http: HttpClient,
146    _handle: tokio::task::JoinHandle<()>,
147}
148
149impl OAuthTestServer {
150    pub async fn start() -> Self {
151        let config = IssuerConfig {
152            port: 0,
153            ..Default::default()
154        };
155        Self::start_with_config(config).await
156    }
157
158    pub async fn start_with_config(config: IssuerConfig) -> Self {
159        // config.port = 0;
160        let mut state = AppState::new(config.clone());
161        let (addr, handle) = state.clone().start().await;
162        let base_url: Url = format!("http://{addr}").parse().unwrap();
163        state.base_url = base_url.to_string().trim_end_matches("/").to_string();
164        let endpoints: OauthEndpoints = OauthEndpoints {
165            oauth_server: base_url.clone().to_string(),
166            discovery: format!("{base_url}.well-known/openid-configuration"),
167            register: format!("{base_url}register"),
168            authorize: format!("{base_url}authorize"),
169            token: format!("{base_url}token"),
170            introspect: format!("{base_url}introspect"),
171            revoke: format!("{base_url}revoke"),
172            userinfo: format!("{base_url}userinfo"),
173            jwks: format!("{base_url}.well-known/jwks.json"),
174            device_code: format!("{base_url}device/code"),
175            device_token: format!("{base_url}device/token"),
176        };
177
178        Self {
179            state,
180            base_url,
181            endpoints,
182            http: HttpClient::builder()
183                .redirect(reqwest::redirect::Policy::none())
184                .build()
185                .unwrap(),
186            _handle: handle,
187        }
188    }
189
190    pub async fn wait_for_shutdown(self) -> Result<(), JoinError> {
191        self._handle.await
192    }
193
194    pub async fn register_client(&self, metadata: serde_json::Value) -> Client {
195        self.state
196            .register_client(metadata)
197            .await
198            .expect("client registration failed")
199    }
200
201    pub async fn register_client_with_secret(&self, metadata: Value, force_secret: bool) -> Client {
202        let mut meta = metadata;
203        if let Some(obj) = meta.as_object_mut() {
204            obj.insert(
205                "generate_client_secret_for_dcr".to_string(),
206                json!(force_secret),
207            );
208        }
209        self.register_client(meta).await
210    }
211
212    pub fn generate_jwt(&self, client: &Client, options: JwtOptions) -> String {
213        self.state
214            .generate_jwt(client, options)
215            .expect("JWT generation failed")
216    }
217
218    pub async fn generate_token(&self, client: &Client, options: JwtOptions) -> Token {
219        self.state
220            .generate_token(client, options)
221            .await
222            .expect("Token generation failed")
223    }
224
225    pub async fn clients(&self) -> Vec<Client> {
226        self.state.store.get_all_clients().await
227    }
228
229    pub async fn codes(&self) -> Vec<AuthorizationCode> {
230        self.state.store.get_all_codes().await
231    }
232
233    pub async fn tokens(&self) -> Vec<Token> {
234        self.state.store.get_all_tokens().await
235    }
236
237    pub async fn refresh_tokens(&self) -> Vec<Token> {
238        self.state.store.get_all_refresh_tokens().await
239    }
240
241    pub async fn clear_clients(&self) {
242        self.state.store.clear_clients().await;
243    }
244
245    pub async fn clear_codes(&self) {
246        self.state.store.clear_codes().await;
247    }
248
249    pub async fn clear_tokens(&self) {
250        self.state.store.clear_tokens().await;
251    }
252
253    pub async fn clear_refresh_tokens(&self) {
254        self.state.store.clear_refresh_tokens().await;
255    }
256
257    pub async fn clear_device_codes(&self) {
258        self.state.store.clear_device_codes().await;
259    }
260
261    pub async fn clear_all(&self) {
262        self.state.store.clear_all().await;
263    }
264
265    pub async fn approve_device_code(&self, device_code: &str, user_id: &str) {
266        self.state
267            .approve_device_code(device_code, user_id)
268            .await
269            .expect("device code not found");
270    }
271
272    pub fn state(&self) -> &AppState {
273        &self.state
274    }
275
276    pub fn jwt_options(&self) -> JwtOptionsBuilder {
277        JwtOptionsBuilder::default()
278    }
279
280    pub fn pkce_pair(&self) -> PkcePair {
281        let verifier_bytes: [u8; 32] = rand::thread_rng().r#gen();
282        let code_verifier = general_purpose::URL_SAFE_NO_PAD.encode(verifier_bytes);
283        let challenge =
284            general_purpose::URL_SAFE_NO_PAD.encode(Sha256::digest(code_verifier.as_bytes()));
285        PkcePair {
286            code_verifier,
287            code_challenge: challenge,
288        }
289    }
290
291    pub fn authorize_url(&self, client: &Client, params: AuthorizeParams) -> Url {
292        let mut url = self.base_url.join("authorize").unwrap();
293        let mut query = url.query_pairs_mut();
294
295        query
296            .append_pair("response_type", params.response_type)
297            .append_pair("client_id", &client.client_id)
298            .append_pair("redirect_uri", &params.redirect_uri)
299            .append_pair("scope", &params.scope);
300
301        if let Some(state) = params.state {
302            query.append_pair("state", &state);
303        }
304
305        if let Some(ref response_mode) = params.response_mode {
306            query.append_pair("response_mode", response_mode);
307        }
308
309        if let Some(pkce) = params.pkce {
310            query
311                .append_pair("code_challenge", &pkce.code_challenge)
312                .append_pair("code_challenge_method", "S256");
313        }
314
315        if let Some(nonce) = params.nonce {
316            query.append_pair("nonce", &nonce);
317        }
318
319        if let Some(ref prompt) = params.prompt {
320            query.append_pair("prompt", prompt);
321        }
322
323        if let Some(ref max_age) = params.max_age {
324            query.append_pair("max_age", max_age);
325        }
326
327        if let Some(ref claims) = params.claims {
328            query.append_pair("claims", claims);
329        }
330
331        if let Some(ref ui_locales) = params.ui_locales {
332            query.append_pair("ui_locales", ui_locales);
333        }
334
335        drop(query);
336        url
337    }
338
339    pub fn rotate_keys(&self) {
340        // In real impl, regenerate KEYS and update JWKS_JSON
341        unimplemented!("Key rotation not implemented in test server")
342    }
343
344    pub async fn approve_consent(&self, auth_url: &Url, user_id: &str) -> String {
345        let resp = self.http.get(auth_url.clone()).send().await.unwrap();
346        assert_eq!(resp.status(), StatusCode::SEE_OTHER);
347
348        let location = resp.headers().get("location").unwrap().to_str().unwrap();
349        let redirect = Url::parse(location).unwrap();
350        let code = redirect
351            .query_pairs()
352            .find(|(k, _)| k == "code")
353            .map(|(_, v)| v.to_string())
354            .expect("no code in redirect");
355
356        // Store user_id for later token claims
357        let code_obj = self
358            .state
359            .store
360            .get_code(&code)
361            .await
362            .expect("code not found");
363        let mut code_obj = code_obj.clone();
364        code_obj.user_id = user_id.to_string();
365        self.state.store.insert_code(code.clone(), code_obj).await;
366
367        code
368    }
369
370    pub async fn exchange_code(
371        &self,
372        client: &Client,
373        code: &str,
374        pkce: Option<&PkcePair>,
375    ) -> Value {
376        let mut form = vec![
377            ("grant_type", "authorization_code"),
378            ("code", code),
379            ("redirect_uri", "http://localhost/cb"),
380        ];
381
382        if let Some(pkce) = pkce {
383            form.push(("code_verifier", &pkce.code_verifier));
384        }
385
386        let resp = self
387            .http
388            .post(self.base_url.join("token").unwrap())
389            .basic_auth(&client.client_id, client.client_secret.as_ref())
390            .form(&form)
391            .send()
392            .await
393            .unwrap();
394
395        assert_eq!(resp.status(), StatusCode::OK);
396        resp.json().await.unwrap()
397    }
398
399    pub async fn refresh_token(&self, client: &Client, refresh_token: &str) -> Value {
400        let resp = self
401            .http
402            .post(self.base_url.join("token").unwrap())
403            .basic_auth(&client.client_id, client.client_secret.as_ref())
404            .form(&[
405                ("grant_type", "refresh_token"),
406                ("refresh_token", refresh_token),
407            ])
408            .send()
409            .await
410            .unwrap();
411
412        resp.json().await.unwrap()
413    }
414
415    pub async fn introspect_token(&self, client: &Client, token: &str) -> Value {
416        let resp = self
417            .http
418            .post(self.base_url.join("introspect").unwrap())
419            .basic_auth(&client.client_id, client.client_secret.as_ref())
420            .form(&[("token", token)])
421            .send()
422            .await
423            .unwrap();
424
425        resp.json().await.unwrap()
426    }
427
428    pub async fn revoke_token(&self, client: &Client, token: &str) {
429        let resp = self
430            .http
431            .post(self.base_url.join("revoke").unwrap())
432            .basic_auth(&client.client_id, client.client_secret.as_ref())
433            .form(&[("token", token)])
434            .send()
435            .await
436            .unwrap();
437
438        assert!(resp.status().is_success());
439    }
440
441    pub fn client_assertion_jwt(&self, client: &Client) -> String {
442        let claims = json!({
443            "iss": client.client_id,
444            "sub": client.client_id,
445            "aud": self.issuer(),
446            "exp": (chrono::Utc::now() + chrono::Duration::minutes(5)).timestamp(),
447            "iat": chrono::Utc::now().timestamp(),
448            "jti": Uuid::new_v4().to_string(),
449        });
450
451        let mut header = Header::new(Algorithm::RS256);
452        header.kid = Some(self.state.keys.kid.clone());
453
454        encode(&header, &claims, &self.state.keys.encoding).unwrap()
455    }
456
457    pub fn base_url(&self) -> &url::Url {
458        &self.base_url
459    }
460
461    pub fn issuer(&self) -> &str {
462        self.state.issuer()
463    }
464
465    /// Complete the full authorization code flow with PKCE in one call.
466    /// Returns the token response including access_token, refresh_token, and optionally id_token.
467    pub async fn complete_auth_flow(
468        &self,
469        client: &Client,
470        params: AuthorizeParams,
471        user_id: &str,
472    ) -> Value {
473        let pkce = params.pkce.clone();
474
475        let auth_url = self.authorize_url(client, params);
476        let code = self.approve_consent(&auth_url, user_id).await;
477
478        self.exchange_code(client, &code, pkce.as_ref()).await
479    }
480
481    /// Complete the full device code flow in one call.
482    /// Returns the token response.
483    pub async fn complete_device_flow(&self, client: &Client, scope: &str, user_id: &str) -> Value {
484        let scope_str = scope.to_string();
485        let device_resp = self
486            .http
487            .post(self.base_url.join("device/code").unwrap())
488            .form(&[("client_id", &client.client_id), ("scope", &scope_str)])
489            .send()
490            .await
491            .unwrap();
492
493        let device_data: Value = device_resp.json().await.unwrap();
494        let device_code = device_data["device_code"].as_str().unwrap();
495
496        self.approve_device_code(device_code, user_id).await;
497
498        let token_resp = self
499            .http
500            .post(self.base_url.join("device/token").unwrap())
501            .form(&[
502                ("grant_type", "urn:ietf:params:oauth:grant-type:device_code"),
503                ("device_code", device_code),
504                ("client_id", &client.client_id),
505            ])
506            .send()
507            .await
508            .unwrap();
509
510        token_resp.json().await.unwrap()
511    }
512
513    /// Perform client credentials grant.
514    pub async fn client_credentials_grant(&self, client: &Client, scope: Option<&str>) -> Value {
515        let mut form = vec![("grant_type", "client_credentials")];
516
517        if let Some(s) = scope {
518            form.push(("scope", s));
519        }
520
521        let resp = self
522            .http
523            .post(self.base_url.join("token").unwrap())
524            .basic_auth(&client.client_id, client.client_secret.as_ref())
525            .form(&form)
526            .send()
527            .await
528            .unwrap();
529
530        assert_eq!(resp.status(), StatusCode::OK);
531        resp.json().await.unwrap()
532    }
533}
534
/// Options for the JWT/token minted by `OAuthTestServer::generate_jwt` and
/// `OAuthTestServer::generate_token`.
///
/// Prefer [`JwtOptionsBuilder`], whose defaults are `user_id = "test-user-123"`
/// and `expires_in = 3600`. The derived [`Default`] instead yields an empty
/// `user_id` and `expires_in = 0`.
#[derive(Debug, Default, Clone)]
pub struct JwtOptions {
    /// User identifier to embed in the token.
    pub user_id: String,
    /// Requested scope, if any.
    pub scope: Option<String>,
    /// Token lifetime (the builder setter names this value `seconds`).
    pub expires_in: i64,
}

/// Fluent builder for [`JwtOptions`]. Unset fields get defaults in [`JwtOptionsBuilder::build`].
#[derive(Debug, Default, Clone)]
pub struct JwtOptionsBuilder {
    user_id: Option<String>,
    scope: Option<String>,
    expires_in: Option<i64>,
}

impl JwtOptionsBuilder {
    /// Sets the user id (default `"test-user-123"`).
    pub fn user_id(mut self, user_id: impl Into<String>) -> Self {
        self.user_id = Some(user_id.into());
        self
    }

    /// Sets the scope (default: none).
    pub fn scope(mut self, scope: impl Into<String>) -> Self {
        self.scope = Some(scope.into());
        self
    }

    /// Sets the lifetime in seconds (default `3600`).
    pub fn expires_in(mut self, seconds: i64) -> Self {
        self.expires_in = Some(seconds);
        self
    }

    /// Finalizes the options, filling unset fields with the defaults above.
    pub fn build(self) -> JwtOptions {
        JwtOptions {
            // Lazily allocate the default only when no user id was set.
            user_id: self
                .user_id
                .unwrap_or_else(|| "test-user-123".to_string()),
            scope: self.scope,
            expires_in: self.expires_in.unwrap_or(3600),
        }
    }
}
573
574use crate::config::IssuerConfig;
575use crate::models::{AuthorizationCode, Client, Token};
576use crate::store::AppState;