oauth2_test_server/
testkit.rs

1use base64::{engine::general_purpose, Engine};
2use jsonwebtoken::{encode, Algorithm, Header};
3use reqwest::{Client as HttpClient, StatusCode};
4use serde_json::{json, Value};
5use sha2::{Digest, Sha256};
6use std::{
7    collections::HashMap,
8    sync::{Arc, RwLock},
9};
10use tokio::task::JoinError;
11use url::Url;
12use uuid::Uuid;
13
/// A PKCE verifier together with the challenge derived from it.
///
/// [`OAuthTestServer::pkce_pair`] produces pairs where `code_challenge` is the
/// unpadded base64url encoding of the SHA-256 digest of `code_verifier`.
#[derive(Debug, Clone)]
pub struct PkcePair {
    /// Secret value sent as `code_verifier` when exchanging the authorization code.
    pub code_verifier: String,
    /// Value sent as `code_challenge` on the authorization request.
    pub code_challenge: String,
}
19
/// Parameters appended to the query string of an authorization request URL.
///
/// Prefer [`AuthorizeParams::new`] over the derived `Default`: the derived
/// implementation leaves `response_type`, `redirect_uri`, `scope` and `state`
/// empty, whereas `new` fills in usable values.
#[derive(Debug, Default)]
pub struct AuthorizeParams {
    /// Value of the `response_type` query parameter.
    pub response_type: &'static str,
    /// Value of the `redirect_uri` query parameter.
    pub redirect_uri: String,
    /// Value of the `scope` query parameter.
    pub scope: String,
    /// Value of the `state` query parameter.
    pub state: String,
    /// When set, `code_challenge` and `code_challenge_method=S256` are added.
    pub pkce: Option<PkcePair>,
    /// When set, added as the `nonce` query parameter.
    pub nonce: Option<String>,
}
29
30impl AuthorizeParams {
31    pub fn new() -> Self {
32        Self {
33            response_type: "code",
34            redirect_uri: "http://localhost/cb".to_string(),
35            scope: "openid".to_string(),
36            state: Uuid::new_v4().to_string(),
37            pkce: None,
38            nonce: None,
39        }
40    }
41
42    pub fn redirect_uri(mut self, uri: impl Into<String>) -> Self {
43        self.redirect_uri = uri.into();
44        self
45    }
46
47    pub fn scope(mut self, scope: impl Into<String>) -> Self {
48        self.scope = scope.into();
49        self
50    }
51
52    pub fn state(mut self, state: impl Into<String>) -> Self {
53        self.state = state.into();
54        self
55    }
56
57    pub fn pkce(mut self, pkce: PkcePair) -> Self {
58        self.pkce = Some(pkce);
59        self
60    }
61
62    pub fn nonce(mut self, nonce: impl Into<String>) -> Self {
63        self.nonce = Some(nonce.into());
64        self
65    }
66}
67
/// Absolute URLs of the endpoints exposed by a running test server.
///
/// All values are built from the server's base URL when it starts.
#[derive(Debug, Clone)]
pub struct OauthEndpoints {
    /// Base URL of the server itself.
    pub oauth_server: String,
    /// OpenID Connect discovery document (`.well-known/openid-configuration`).
    pub discovery: String,
    /// Authorization endpoint.
    pub authorize: String,
    /// Token endpoint.
    pub token: String,
    /// Client registration endpoint.
    // NOTE(review): the field name is a misspelling of `register`; it is public,
    // so renaming it would break existing callers.
    pub regsiter: String,
    /// Token introspection endpoint.
    pub introspect: String,
    /// Token revocation endpoint.
    pub revoke: String,
    /// UserInfo endpoint.
    pub userinfo: String,
    /// JSON Web Key Set document (`.well-known/jwks.json`).
    pub jwks: String,
}
80
/// A running OAuth test server with full programmatic control.
///
/// ```
/// use oauth2_test_server::OAuthTestServer;
///
/// #[tokio::test]
/// async fn test() {
///     let server = OAuthTestServer::start().await;
///     println!("server: {}", server.base_url());
///     println!("authorize endpoint: {}", server.endpoints.authorize);
///     // register a client
///     let client = server.register_client(
///         serde_json::json!({ "scope": "openid", "redirect_uris": ["http://localhost:8080/callback"] }),
///     );
///     // generate a jwt
///     let jwt = server.generate_jwt(&client, server.jwt_options().user_id("bob").build());
///     assert_eq!(jwt.split('.').count(), 3);
///     // The stores are `std::sync::RwLock`s, so the guard must be unwrapped
///     // before the map itself can be inspected.
///     assert_eq!(server.clients().read().unwrap().len(), 1);
///     assert_eq!(server.tokens().read().unwrap().len(), 1);
/// }
/// ```
pub struct OAuthTestServer {
    /// Server state shared with the running server task.
    state: AppState,
    /// Base URL the server is listening on.
    pub base_url: url::Url,
    /// Pre-computed endpoint URLs derived from `base_url`.
    pub endpoints: OauthEndpoints,
    /// HTTP client used by the request helper methods.
    http: HttpClient,
    /// Join handle of the spawned server task; awaited by `wait_for_shutdown`.
    _handle: tokio::task::JoinHandle<()>,
}
109
110impl OAuthTestServer {
111    pub async fn start() -> Self {
112        let config = IssuerConfig {
113            port: 0,
114            ..Default::default()
115        };
116        Self::start_with_config(config).await
117    }
118
119    pub fn clients(&self) -> Arc<RwLock<HashMap<String, Client>>> {
120        self.state.clients.clone()
121    }
122
123    pub fn codes(&self) -> Arc<RwLock<HashMap<String, AuthorizationCode>>> {
124        self.state.codes.clone()
125    }
126
127    pub fn tokens(&self) -> Arc<RwLock<HashMap<String, Token>>> {
128        self.state.tokens.clone()
129    }
130
131    pub fn refresh_tokens(&self) -> Arc<RwLock<HashMap<String, Token>>> {
132        self.state.refresh_tokens.clone()
133    }
134
135    pub async fn start_with_config(config: IssuerConfig) -> Self {
136        // config.port = 0;
137        let state = AppState::new(config.clone());
138        let (addr, handle) = state.clone().start().await;
139        let base_url: Url = format!("http://{addr}").parse().unwrap();
140
141        let endpoints: OauthEndpoints = OauthEndpoints {
142            oauth_server: base_url.clone().to_string(),
143            discovery: format!("{base_url}.well-known/openid-configuration"),
144            authorize: format!("{base_url}register"),
145            regsiter: format!("{base_url}authorize"),
146            token: format!("{base_url}token"),
147            introspect: format!("{base_url}introspect"),
148            revoke: format!("{base_url}revoke"),
149            userinfo: format!("{base_url}userinfo"),
150            jwks: format!("{base_url}.well-known/jwks.json"),
151        };
152
153        Self {
154            state,
155            base_url,
156            endpoints,
157            http: HttpClient::new(),
158            _handle: handle,
159        }
160    }
161
    /// Consumes the test server and waits for the server task to finish.
    ///
    /// # Errors
    ///
    /// Returns the [`JoinError`] produced by awaiting the task's join handle.
    pub async fn wait_for_shutdown(self) -> Result<(), JoinError> {
        self._handle.await
    }
165
    /// Registers a client described by `metadata` directly on the server state.
    ///
    /// # Panics
    ///
    /// Panics with `client registration failed` if the registration does not
    /// succeed.
    pub fn register_client(&self, metadata: serde_json::Value) -> Client {
        self.state
            .register_client(metadata)
            .expect("client registration failed")
    }
171
172    pub fn register_client_with_secret(&self, metadata: Value, force_secret: bool) -> Client {
173        let mut meta = metadata;
174        if let Some(obj) = meta.as_object_mut() {
175            obj.insert(
176                "generate_client_secret_for_dcr".to_string(),
177                json!(force_secret),
178            );
179        }
180        self.register_client(meta)
181    }
182
    /// Generates a JWT for `client` with the given `options` via the server state.
    ///
    /// # Panics
    ///
    /// Panics with `JWT generation failed` if generation does not succeed.
    pub fn generate_jwt(&self, client: &Client, options: JwtOptions) -> String {
        self.state
            .generate_jwt(client, options)
            .expect("JWT generation failed")
    }
188
    /// Returns a fresh [`JwtOptionsBuilder`] with no options set.
    pub fn jwt_options(&self) -> JwtOptionsBuilder {
        JwtOptionsBuilder::default()
    }
192
193    pub fn pkce_pair(&self) -> PkcePair {
194        use rand::Rng;
195        let verifier_bytes: [u8; 32] = rand::thread_rng().r#gen();
196        let code_verifier = general_purpose::URL_SAFE_NO_PAD.encode(verifier_bytes);
197        let challenge =
198            general_purpose::URL_SAFE_NO_PAD.encode(Sha256::digest(code_verifier.as_bytes()));
199        PkcePair {
200            code_verifier,
201            code_challenge: challenge,
202        }
203    }
204
205    pub fn authorize_url(&self, client: &Client, params: AuthorizeParams) -> Url {
206        let mut url = self.base_url.join("authorize").unwrap();
207        let mut query = url.query_pairs_mut();
208
209        query
210            .append_pair("response_type", params.response_type)
211            .append_pair("client_id", &client.client_id)
212            .append_pair("redirect_uri", &params.redirect_uri)
213            .append_pair("scope", &params.scope)
214            .append_pair("state", &params.state);
215
216        if let Some(pkce) = params.pkce {
217            query
218                .append_pair("code_challenge", &pkce.code_challenge)
219                .append_pair("code_challenge_method", "S256");
220        }
221
222        if let Some(nonce) = params.nonce {
223            query.append_pair("nonce", &nonce);
224        }
225
226        drop(query);
227        url
228    }
229
    /// Placeholder for signing-key rotation.
    ///
    /// # Panics
    ///
    /// Always panics: key rotation is not implemented in the test server.
    pub fn rotate_keys(&self) {
        // In real impl, regenerate KEYS and update JWKS_JSON
        unimplemented!("Key rotation not implemented in test server")
    }
234
235    pub async fn approve_consent(&self, auth_url: &Url, user_id: &str) -> String {
236        let resp = self.http.get(auth_url.clone()).send().await.unwrap();
237        assert_eq!(resp.status(), StatusCode::SEE_OTHER);
238
239        let location = resp.headers().get("location").unwrap().to_str().unwrap();
240        let redirect = Url::parse(location).unwrap();
241        let code = redirect
242            .query_pairs()
243            .find(|(k, _)| k == "code")
244            .map(|(_, v)| v.to_string())
245            .expect("no code in redirect");
246
247        // Store user_id for later token claims
248        let code_obj = self
249            .state
250            .codes
251            .read()
252            .unwrap()
253            .get(&code)
254            .cloned()
255            .unwrap();
256        let mut code_obj = code_obj;
257        code_obj.user_id = user_id.to_string();
258        self.state
259            .codes
260            .write()
261            .unwrap()
262            .insert(code.clone(), code_obj);
263
264        code
265    }
266
267    pub async fn exchange_code(
268        &self,
269        client: &Client,
270        code: &str,
271        pkce: Option<&PkcePair>,
272    ) -> Value {
273        let mut form = vec![
274            ("grant_type", "authorization_code"),
275            ("code", code),
276            ("redirect_uri", "http://localhost/cb"),
277        ];
278
279        if let Some(pkce) = pkce {
280            form.push(("code_verifier", &pkce.code_verifier));
281        }
282
283        let resp = self
284            .http
285            .post(self.base_url.join("token").unwrap())
286            .basic_auth(&client.client_id, client.client_secret.as_ref())
287            .form(&form)
288            .send()
289            .await
290            .unwrap();
291
292        assert_eq!(resp.status(), StatusCode::OK);
293        resp.json().await.unwrap()
294    }
295
296    pub async fn refresh_token(&self, client: &Client, refresh_token: &str) -> Value {
297        let resp = self
298            .http
299            .post(self.base_url.join("token").unwrap())
300            .basic_auth(&client.client_id, client.client_secret.as_ref())
301            .form(&[
302                ("grant_type", "refresh_token"),
303                ("refresh_token", refresh_token),
304            ])
305            .send()
306            .await
307            .unwrap();
308
309        resp.json().await.unwrap()
310    }
311
312    pub async fn introspect_token(&self, client: &Client, token: &str) -> Value {
313        let resp = self
314            .http
315            .post(self.base_url.join("introspect").unwrap())
316            .basic_auth(&client.client_id, client.client_secret.as_ref())
317            .form(&[("token", token)])
318            .send()
319            .await
320            .unwrap();
321
322        resp.json().await.unwrap()
323    }
324
325    pub async fn revoke_token(&self, client: &Client, token: &str) {
326        let resp = self
327            .http
328            .post(self.base_url.join("revoke").unwrap())
329            .basic_auth(&client.client_id, client.client_secret.as_ref())
330            .form(&[("token", token)])
331            .send()
332            .await
333            .unwrap();
334
335        assert!(resp.status().is_success());
336    }
337
338    pub fn client_assertion_jwt(&self, client: &Client) -> String {
339        let claims = json!({
340            "iss": client.client_id,
341            "sub": client.client_id,
342            "aud": self.issuer(),
343            "exp": (chrono::Utc::now() + chrono::Duration::minutes(5)).timestamp(),
344            "iat": chrono::Utc::now().timestamp(),
345            "jti": Uuid::new_v4().to_string(),
346        });
347
348        let mut header = Header::new(Algorithm::RS256);
349        header.kid = Some(KID.to_string());
350
351        encode(&header, &claims, &KEYS.encoding).unwrap()
352    }
353
    /// Returns the base URL the server is listening on.
    pub fn base_url(&self) -> &url::Url {
        &self.base_url
    }
357
    /// Returns the issuer identifier reported by the server state.
    pub fn issuer(&self) -> &str {
        self.state.issuer()
    }
361}
362
/// Options for generating a JWT.
///
/// The derived `Default` yields an empty `user_id` and `expires_in == 0`; use
/// [`JwtOptionsBuilder`] to obtain the builder's defaults instead.
#[derive(Debug, Default)]
pub struct JwtOptions {
    /// Identifier of the user the token is generated for.
    pub user_id: String,
    /// Optional scope value.
    pub scope: Option<String>,
    /// Token lifetime in seconds.
    pub expires_in: i64,
}

/// Builder for [`JwtOptions`].
///
/// Unset fields fall back to defaults in [`JwtOptionsBuilder::build`].
#[derive(Debug, Default)]
pub struct JwtOptionsBuilder {
    user_id: Option<String>,
    scope: Option<String>,
    expires_in: Option<i64>,
}

impl JwtOptionsBuilder {
    /// Sets the user id.
    pub fn user_id(mut self, user_id: impl Into<String>) -> Self {
        self.user_id = Some(user_id.into());
        self
    }

    /// Sets the scope.
    pub fn scope(mut self, scope: impl Into<String>) -> Self {
        self.scope = Some(scope.into());
        self
    }

    /// Sets the token lifetime in seconds.
    pub fn expires_in(mut self, seconds: i64) -> Self {
        self.expires_in = Some(seconds);
        self
    }

    /// Finalizes the options.
    ///
    /// Defaults: `user_id` is `"test-user-123"`, `scope` is `None`, and
    /// `expires_in` is `3600` seconds.
    pub fn build(self) -> JwtOptions {
        JwtOptions {
            // Lazily build the fallback so no string is allocated when a user
            // id was supplied.
            user_id: self
                .user_id
                .unwrap_or_else(|| "test-user-123".to_string()),
            scope: self.scope,
            expires_in: self.expires_in.unwrap_or(3600),
        }
    }
}
401
402use crate::server::{AppState, AuthorizationCode, Client, IssuerConfig, Token, KEYS, KID};