oauth2_test_server/
testkit.rs

1use base64::{engine::general_purpose, Engine};
2use jsonwebtoken::{encode, Algorithm, Header};
3use reqwest::{Client as HttpClient, StatusCode};
4use serde_json::{json, Value};
5use sha2::{Digest, Sha256};
6use std::{
7    collections::HashMap,
8    sync::{Arc, RwLock},
9};
10use tokio::task::JoinError;
11use url::Url;
12use uuid::Uuid;
13
/// A PKCE (RFC 7636) verifier together with its derived challenge.
///
/// Produced by `OAuthTestServer::pkce_pair`, which sets `code_challenge` to the unpadded
/// base64url SHA-256 digest of `code_verifier` (the `S256` method).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PkcePair {
    /// Secret sent to the token endpoint when exchanging the authorization code.
    pub code_verifier: String,
    /// Value sent to the authorization endpoint as `code_challenge`.
    pub code_challenge: String,
}
19
20#[derive(Debug, Default)]
21pub struct AuthorizeParams {
22    pub response_type: &'static str,
23    pub redirect_uri: String,
24    pub scope: String,
25    pub state: String,
26    pub pkce: Option<PkcePair>,
27    pub nonce: Option<String>,
28}
29
30impl AuthorizeParams {
31    pub fn new() -> Self {
32        Self {
33            response_type: "code",
34            redirect_uri: "http://localhost/cb".to_string(),
35            scope: "openid".to_string(),
36            state: Uuid::new_v4().to_string(),
37            pkce: None,
38            nonce: None,
39        }
40    }
41
42    pub fn redirect_uri(mut self, uri: impl Into<String>) -> Self {
43        self.redirect_uri = uri.into();
44        self
45    }
46
47    pub fn scope(mut self, scope: impl Into<String>) -> Self {
48        self.scope = scope.into();
49        self
50    }
51
52    pub fn state(mut self, state: impl Into<String>) -> Self {
53        self.state = state.into();
54        self
55    }
56
57    pub fn pkce(mut self, pkce: PkcePair) -> Self {
58        self.pkce = Some(pkce);
59        self
60    }
61
62    pub fn nonce(mut self, nonce: impl Into<String>) -> Self {
63        self.nonce = Some(nonce.into());
64        self
65    }
66}
67
/// Absolute URLs of the endpoints exposed by a running `OAuthTestServer`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct OauthEndpoints {
    /// Root URL of the server.
    pub oauth_server: String,
    /// OpenID Connect discovery document (`.well-known/openid-configuration`).
    pub discovery: String,
    /// Authorization endpoint.
    pub authorize: String,
    /// Token endpoint.
    pub token: String,
    /// Client registration endpoint.
    ///
    /// The misspelled field name is kept for backward compatibility; prefer
    /// [`OauthEndpoints::register`].
    pub regsiter: String,
    /// Token introspection endpoint.
    pub introspect: String,
    /// Token revocation endpoint.
    pub revoke: String,
    /// OpenID Connect userinfo endpoint.
    pub userinfo: String,
    /// JSON Web Key Set document (`.well-known/jwks.json`).
    pub jwks: String,
}

impl OauthEndpoints {
    /// URL of the client registration endpoint (correctly spelled accessor for `regsiter`).
    pub fn register(&self) -> &str {
        &self.regsiter
    }
}
80
81/// Start a test server with full programmatic control.
82///
83/// ```
84/// use oauth2_test_server::OAuthTestServer;
85///
86/// #[tokio::test]
87/// async fn test() {
88/// let server = OAuthTestServer::start().await;
89/// println!("server: {}", server.base_url());
90/// println!("authorize endpoint: {}", server.endpoints.authorize_url);
91/// // register a client
92/// let client = server.register_client(
93///     serde_json::json!({ "scope": "openid", "redirect_uris":["http://localhost:8080/callback"]}),
94/// );
95/// // generate a jwt
96/// let jwt = server.generate_jwt(&client, server.jwt_options().user_id("bob").build());
97/// assert_eq!(jwt.split('.').count(), 3);
98/// assert_eq!(server.clients().read().iter().len(), 1);
99/// assert_eq!(server.tokens().read().iter().len(), 1);
100/// }
101/// ```
102pub struct OAuthTestServer {
103    state: AppState,
104    pub base_url: url::Url,
105    pub endpoints: OauthEndpoints,
106    http: HttpClient,
107    _handle: tokio::task::JoinHandle<()>,
108}
109
110impl OAuthTestServer {
111    pub async fn start() -> Self {
112        let config = IssuerConfig {
113            port: 0,
114            ..Default::default()
115        };
116        Self::start_with_config(config).await
117    }
118
119    pub fn clients(&self) -> Arc<RwLock<HashMap<String, Client>>> {
120        self.state.clients.clone()
121    }
122
123    pub fn codes(&self) -> Arc<RwLock<HashMap<String, AuthorizationCode>>> {
124        self.state.codes.clone()
125    }
126
127    pub fn tokens(&self) -> Arc<RwLock<HashMap<String, Token>>> {
128        self.state.tokens.clone()
129    }
130
131    pub fn refresh_tokens(&self) -> Arc<RwLock<HashMap<String, Token>>> {
132        self.state.refresh_tokens.clone()
133    }
134
135    pub async fn start_with_config(config: IssuerConfig) -> Self {
136        // config.port = 0;
137        let mut state = AppState::new(config.clone());
138        let (addr, handle) = state.clone().start().await;
139        let base_url: Url = format!("http://{addr}").parse().unwrap();
140        state.base_url = base_url.to_string().trim_end_matches("/").to_string();
141        let endpoints: OauthEndpoints = OauthEndpoints {
142            oauth_server: base_url.clone().to_string(),
143            discovery: format!("{base_url}.well-known/openid-configuration"),
144            authorize: format!("{base_url}register"),
145            regsiter: format!("{base_url}authorize"),
146            token: format!("{base_url}token"),
147            introspect: format!("{base_url}introspect"),
148            revoke: format!("{base_url}revoke"),
149            userinfo: format!("{base_url}userinfo"),
150            jwks: format!("{base_url}.well-known/jwks.json"),
151        };
152
153        Self {
154            state,
155            base_url,
156            endpoints,
157            http: HttpClient::new(),
158            _handle: handle,
159        }
160    }
161
162    pub async fn wait_for_shutdown(self) -> Result<(), JoinError> {
163        self._handle.await
164    }
165
166    pub fn register_client(&self, metadata: serde_json::Value) -> Client {
167        self.state
168            .register_client(metadata)
169            .expect("client registration failed")
170    }
171
172    pub fn register_client_with_secret(&self, metadata: Value, force_secret: bool) -> Client {
173        let mut meta = metadata;
174        if let Some(obj) = meta.as_object_mut() {
175            obj.insert(
176                "generate_client_secret_for_dcr".to_string(),
177                json!(force_secret),
178            );
179        }
180        self.register_client(meta)
181    }
182
183    pub fn generate_jwt(&self, client: &Client, options: JwtOptions) -> String {
184        self.state
185            .generate_jwt(client, options)
186            .expect("JWT generation failed")
187    }
188
189    pub fn generate_token(&self, client: &Client, options: JwtOptions) -> Token {
190        self.state
191            .generate_token(client, options)
192            .expect("Token generation failed")
193    }
194
195    pub fn jwt_options(&self) -> JwtOptionsBuilder {
196        JwtOptionsBuilder::default()
197    }
198
199    pub fn pkce_pair(&self) -> PkcePair {
200        use rand::Rng;
201        let verifier_bytes: [u8; 32] = rand::thread_rng().r#gen();
202        let code_verifier = general_purpose::URL_SAFE_NO_PAD.encode(verifier_bytes);
203        let challenge =
204            general_purpose::URL_SAFE_NO_PAD.encode(Sha256::digest(code_verifier.as_bytes()));
205        PkcePair {
206            code_verifier,
207            code_challenge: challenge,
208        }
209    }
210
211    pub fn authorize_url(&self, client: &Client, params: AuthorizeParams) -> Url {
212        let mut url = self.base_url.join("authorize").unwrap();
213        let mut query = url.query_pairs_mut();
214
215        query
216            .append_pair("response_type", params.response_type)
217            .append_pair("client_id", &client.client_id)
218            .append_pair("redirect_uri", &params.redirect_uri)
219            .append_pair("scope", &params.scope)
220            .append_pair("state", &params.state);
221
222        if let Some(pkce) = params.pkce {
223            query
224                .append_pair("code_challenge", &pkce.code_challenge)
225                .append_pair("code_challenge_method", "S256");
226        }
227
228        if let Some(nonce) = params.nonce {
229            query.append_pair("nonce", &nonce);
230        }
231
232        drop(query);
233        url
234    }
235
236    pub fn rotate_keys(&self) {
237        // In real impl, regenerate KEYS and update JWKS_JSON
238        unimplemented!("Key rotation not implemented in test server")
239    }
240
241    pub async fn approve_consent(&self, auth_url: &Url, user_id: &str) -> String {
242        let resp = self.http.get(auth_url.clone()).send().await.unwrap();
243        assert_eq!(resp.status(), StatusCode::SEE_OTHER);
244
245        let location = resp.headers().get("location").unwrap().to_str().unwrap();
246        let redirect = Url::parse(location).unwrap();
247        let code = redirect
248            .query_pairs()
249            .find(|(k, _)| k == "code")
250            .map(|(_, v)| v.to_string())
251            .expect("no code in redirect");
252
253        // Store user_id for later token claims
254        let code_obj = self
255            .state
256            .codes
257            .read()
258            .unwrap()
259            .get(&code)
260            .cloned()
261            .unwrap();
262        let mut code_obj = code_obj;
263        code_obj.user_id = user_id.to_string();
264        self.state
265            .codes
266            .write()
267            .unwrap()
268            .insert(code.clone(), code_obj);
269
270        code
271    }
272
273    pub async fn exchange_code(
274        &self,
275        client: &Client,
276        code: &str,
277        pkce: Option<&PkcePair>,
278    ) -> Value {
279        let mut form = vec![
280            ("grant_type", "authorization_code"),
281            ("code", code),
282            ("redirect_uri", "http://localhost/cb"),
283        ];
284
285        if let Some(pkce) = pkce {
286            form.push(("code_verifier", &pkce.code_verifier));
287        }
288
289        let resp = self
290            .http
291            .post(self.base_url.join("token").unwrap())
292            .basic_auth(&client.client_id, client.client_secret.as_ref())
293            .form(&form)
294            .send()
295            .await
296            .unwrap();
297
298        assert_eq!(resp.status(), StatusCode::OK);
299        resp.json().await.unwrap()
300    }
301
302    pub async fn refresh_token(&self, client: &Client, refresh_token: &str) -> Value {
303        let resp = self
304            .http
305            .post(self.base_url.join("token").unwrap())
306            .basic_auth(&client.client_id, client.client_secret.as_ref())
307            .form(&[
308                ("grant_type", "refresh_token"),
309                ("refresh_token", refresh_token),
310            ])
311            .send()
312            .await
313            .unwrap();
314
315        resp.json().await.unwrap()
316    }
317
318    pub async fn introspect_token(&self, client: &Client, token: &str) -> Value {
319        let resp = self
320            .http
321            .post(self.base_url.join("introspect").unwrap())
322            .basic_auth(&client.client_id, client.client_secret.as_ref())
323            .form(&[("token", token)])
324            .send()
325            .await
326            .unwrap();
327
328        resp.json().await.unwrap()
329    }
330
331    pub async fn revoke_token(&self, client: &Client, token: &str) {
332        let resp = self
333            .http
334            .post(self.base_url.join("revoke").unwrap())
335            .basic_auth(&client.client_id, client.client_secret.as_ref())
336            .form(&[("token", token)])
337            .send()
338            .await
339            .unwrap();
340
341        assert!(resp.status().is_success());
342    }
343
344    pub fn client_assertion_jwt(&self, client: &Client) -> String {
345        let claims = json!({
346            "iss": client.client_id,
347            "sub": client.client_id,
348            "aud": self.issuer(),
349            "exp": (chrono::Utc::now() + chrono::Duration::minutes(5)).timestamp(),
350            "iat": chrono::Utc::now().timestamp(),
351            "jti": Uuid::new_v4().to_string(),
352        });
353
354        let mut header = Header::new(Algorithm::RS256);
355        header.kid = Some(KID.to_string());
356
357        encode(&header, &claims, &KEYS.encoding).unwrap()
358    }
359
360    pub fn base_url(&self) -> &url::Url {
361        &self.base_url
362    }
363
364    pub fn issuer(&self) -> &str {
365        self.state.issuer()
366    }
367}
368
/// Claims and lifetime options for a JWT or token minted by `OAuthTestServer::generate_jwt`
/// or `generate_token`.
///
/// Build with [`JwtOptionsBuilder`]. `JwtOptions::default()` yields the same values as
/// `JwtOptionsBuilder::default().build()`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct JwtOptions {
    /// Identifier of the user the token is issued for.
    pub user_id: String,
    /// Optional scope string; `None` when not set.
    pub scope: Option<String>,
    /// Token lifetime in seconds (the builder's setter takes `seconds`).
    pub expires_in: i64,
}

impl Default for JwtOptions {
    /// Builder defaults: user `"test-user-123"`, no scope, 3600 seconds. A derived `Default`
    /// would give an empty user and a lifetime of 0, i.e. an already-expired token.
    fn default() -> Self {
        JwtOptionsBuilder::default().build()
    }
}

/// Fluent builder for [`JwtOptions`]; unset fields fall back to defaults in
/// [`JwtOptionsBuilder::build`].
#[derive(Debug, Default, Clone)]
pub struct JwtOptionsBuilder {
    user_id: Option<String>,
    scope: Option<String>,
    expires_in: Option<i64>,
}

impl JwtOptionsBuilder {
    /// User id used when [`Self::user_id`] was not called.
    const DEFAULT_USER_ID: &'static str = "test-user-123";
    /// Lifetime in seconds used when [`Self::expires_in`] was not called.
    const DEFAULT_EXPIRES_IN: i64 = 3600;

    /// Set the user the token is issued for.
    pub fn user_id(mut self, user_id: impl Into<String>) -> Self {
        self.user_id = Some(user_id.into());
        self
    }

    /// Set the token's scope string.
    pub fn scope(mut self, scope: impl Into<String>) -> Self {
        self.scope = Some(scope.into());
        self
    }

    /// Set the token lifetime in seconds.
    pub fn expires_in(mut self, seconds: i64) -> Self {
        self.expires_in = Some(seconds);
        self
    }

    /// Produce the options, substituting defaults for anything left unset.
    pub fn build(self) -> JwtOptions {
        JwtOptions {
            user_id: self
                .user_id
                .unwrap_or_else(|| Self::DEFAULT_USER_ID.to_owned()),
            scope: self.scope,
            expires_in: self.expires_in.unwrap_or(Self::DEFAULT_EXPIRES_IN),
        }
    }
}
407
408use crate::server::{AppState, AuthorizationCode, Client, IssuerConfig, Token, KEYS, KID};