1use base64::{engine::general_purpose, Engine};
2use jsonwebtoken::{encode, Algorithm, Header};
3use reqwest::{Client as HttpClient, StatusCode};
4use serde_json::{json, Value};
5use sha2::{Digest, Sha256};
6use std::{
7 collections::HashMap,
8 sync::{Arc, RwLock},
9};
10use tokio::task::JoinError;
11use url::Url;
12use uuid::Uuid;
13
/// A PKCE (RFC 7636) verifier/challenge pair.
///
/// `OAuthTestServer::pkce_pair` fills `code_challenge` with the S256 transform
/// of `code_verifier`: SHA-256 of the verifier, base64url-encoded without padding.
#[derive(Clone, Debug)]
pub struct PkcePair {
    /// High-entropy secret, sent later with the token request.
    pub code_verifier: String,
    /// Value derived from the verifier, sent with the authorization request.
    pub code_challenge: String,
}
19
20#[derive(Debug, Default)]
21pub struct AuthorizeParams {
22 pub response_type: &'static str,
23 pub redirect_uri: String,
24 pub scope: String,
25 pub state: String,
26 pub pkce: Option<PkcePair>,
27 pub nonce: Option<String>,
28}
29
30impl AuthorizeParams {
31 pub fn new() -> Self {
32 Self {
33 response_type: "code",
34 redirect_uri: "http://localhost/cb".to_string(),
35 scope: "openid".to_string(),
36 state: Uuid::new_v4().to_string(),
37 pkce: None,
38 nonce: None,
39 }
40 }
41
42 pub fn redirect_uri(mut self, uri: impl Into<String>) -> Self {
43 self.redirect_uri = uri.into();
44 self
45 }
46
47 pub fn scope(mut self, scope: impl Into<String>) -> Self {
48 self.scope = scope.into();
49 self
50 }
51
52 pub fn state(mut self, state: impl Into<String>) -> Self {
53 self.state = state.into();
54 self
55 }
56
57 pub fn pkce(mut self, pkce: PkcePair) -> Self {
58 self.pkce = Some(pkce);
59 self
60 }
61
62 pub fn nonce(mut self, nonce: impl Into<String>) -> Self {
63 self.nonce = Some(nonce.into());
64 self
65 }
66}
67
/// Absolute URLs of the test server's endpoints, derived from its base URL
/// when the server starts.
#[derive(Debug, Clone)]
pub struct OauthEndpoints {
    /// Server root URL.
    pub oauth_server: String,
    /// OpenID Connect discovery document (`.well-known/openid-configuration`).
    pub discovery: String,
    /// Authorization endpoint.
    pub authorize: String,
    /// Token endpoint.
    pub token: String,
    /// Dynamic client registration endpoint.
    ///
    /// The field name is misspelled and is kept only for backward
    /// compatibility; prefer [`OauthEndpoints::register`].
    pub regsiter: String,
    /// Token introspection endpoint.
    pub introspect: String,
    /// Token revocation endpoint.
    pub revoke: String,
    /// OpenID Connect userinfo endpoint.
    pub userinfo: String,
    /// JSON Web Key Set document (`.well-known/jwks.json`).
    pub jwks: String,
}

impl OauthEndpoints {
    /// Dynamic client registration endpoint.
    ///
    /// This is a correctly spelled accessor for the `regsiter` field.
    pub fn register(&self) -> &str {
        &self.regsiter
    }
}
80
81pub struct OAuthTestServer {
103 state: AppState,
104 pub base_url: url::Url,
105 pub endpoints: OauthEndpoints,
106 http: HttpClient,
107 _handle: tokio::task::JoinHandle<()>,
108}
109
110impl OAuthTestServer {
111 pub async fn start() -> Self {
112 let config = IssuerConfig {
113 port: 0,
114 ..Default::default()
115 };
116 Self::start_with_config(config).await
117 }
118
119 pub fn clients(&self) -> Arc<RwLock<HashMap<String, Client>>> {
120 self.state.clients.clone()
121 }
122
123 pub fn codes(&self) -> Arc<RwLock<HashMap<String, AuthorizationCode>>> {
124 self.state.codes.clone()
125 }
126
127 pub fn tokens(&self) -> Arc<RwLock<HashMap<String, Token>>> {
128 self.state.tokens.clone()
129 }
130
131 pub fn refresh_tokens(&self) -> Arc<RwLock<HashMap<String, Token>>> {
132 self.state.refresh_tokens.clone()
133 }
134
135 pub async fn start_with_config(config: IssuerConfig) -> Self {
136 let mut state = AppState::new(config.clone());
138 let (addr, handle) = state.clone().start().await;
139 let base_url: Url = format!("http://{addr}").parse().unwrap();
140 state.base_url = base_url.to_string().trim_end_matches("/").to_string();
141 let endpoints: OauthEndpoints = OauthEndpoints {
142 oauth_server: base_url.clone().to_string(),
143 discovery: format!("{base_url}.well-known/openid-configuration"),
144 authorize: format!("{base_url}register"),
145 regsiter: format!("{base_url}authorize"),
146 token: format!("{base_url}token"),
147 introspect: format!("{base_url}introspect"),
148 revoke: format!("{base_url}revoke"),
149 userinfo: format!("{base_url}userinfo"),
150 jwks: format!("{base_url}.well-known/jwks.json"),
151 };
152
153 Self {
154 state,
155 base_url,
156 endpoints,
157 http: HttpClient::new(),
158 _handle: handle,
159 }
160 }
161
162 pub async fn wait_for_shutdown(self) -> Result<(), JoinError> {
163 self._handle.await
164 }
165
166 pub fn register_client(&self, metadata: serde_json::Value) -> Client {
167 self.state
168 .register_client(metadata)
169 .expect("client registration failed")
170 }
171
172 pub fn register_client_with_secret(&self, metadata: Value, force_secret: bool) -> Client {
173 let mut meta = metadata;
174 if let Some(obj) = meta.as_object_mut() {
175 obj.insert(
176 "generate_client_secret_for_dcr".to_string(),
177 json!(force_secret),
178 );
179 }
180 self.register_client(meta)
181 }
182
183 pub fn generate_jwt(&self, client: &Client, options: JwtOptions) -> String {
184 self.state
185 .generate_jwt(client, options)
186 .expect("JWT generation failed")
187 }
188
189 pub fn generate_token(&self, client: &Client, options: JwtOptions) -> Token {
190 self.state
191 .generate_token(client, options)
192 .expect("Token generation failed")
193 }
194
195 pub fn jwt_options(&self) -> JwtOptionsBuilder {
196 JwtOptionsBuilder::default()
197 }
198
199 pub fn pkce_pair(&self) -> PkcePair {
200 use rand::Rng;
201 let verifier_bytes: [u8; 32] = rand::thread_rng().r#gen();
202 let code_verifier = general_purpose::URL_SAFE_NO_PAD.encode(verifier_bytes);
203 let challenge =
204 general_purpose::URL_SAFE_NO_PAD.encode(Sha256::digest(code_verifier.as_bytes()));
205 PkcePair {
206 code_verifier,
207 code_challenge: challenge,
208 }
209 }
210
211 pub fn authorize_url(&self, client: &Client, params: AuthorizeParams) -> Url {
212 let mut url = self.base_url.join("authorize").unwrap();
213 let mut query = url.query_pairs_mut();
214
215 query
216 .append_pair("response_type", params.response_type)
217 .append_pair("client_id", &client.client_id)
218 .append_pair("redirect_uri", ¶ms.redirect_uri)
219 .append_pair("scope", ¶ms.scope)
220 .append_pair("state", ¶ms.state);
221
222 if let Some(pkce) = params.pkce {
223 query
224 .append_pair("code_challenge", &pkce.code_challenge)
225 .append_pair("code_challenge_method", "S256");
226 }
227
228 if let Some(nonce) = params.nonce {
229 query.append_pair("nonce", &nonce);
230 }
231
232 drop(query);
233 url
234 }
235
236 pub fn rotate_keys(&self) {
237 unimplemented!("Key rotation not implemented in test server")
239 }
240
241 pub async fn approve_consent(&self, auth_url: &Url, user_id: &str) -> String {
242 let resp = self.http.get(auth_url.clone()).send().await.unwrap();
243 assert_eq!(resp.status(), StatusCode::SEE_OTHER);
244
245 let location = resp.headers().get("location").unwrap().to_str().unwrap();
246 let redirect = Url::parse(location).unwrap();
247 let code = redirect
248 .query_pairs()
249 .find(|(k, _)| k == "code")
250 .map(|(_, v)| v.to_string())
251 .expect("no code in redirect");
252
253 let code_obj = self
255 .state
256 .codes
257 .read()
258 .unwrap()
259 .get(&code)
260 .cloned()
261 .unwrap();
262 let mut code_obj = code_obj;
263 code_obj.user_id = user_id.to_string();
264 self.state
265 .codes
266 .write()
267 .unwrap()
268 .insert(code.clone(), code_obj);
269
270 code
271 }
272
273 pub async fn exchange_code(
274 &self,
275 client: &Client,
276 code: &str,
277 pkce: Option<&PkcePair>,
278 ) -> Value {
279 let mut form = vec![
280 ("grant_type", "authorization_code"),
281 ("code", code),
282 ("redirect_uri", "http://localhost/cb"),
283 ];
284
285 if let Some(pkce) = pkce {
286 form.push(("code_verifier", &pkce.code_verifier));
287 }
288
289 let resp = self
290 .http
291 .post(self.base_url.join("token").unwrap())
292 .basic_auth(&client.client_id, client.client_secret.as_ref())
293 .form(&form)
294 .send()
295 .await
296 .unwrap();
297
298 assert_eq!(resp.status(), StatusCode::OK);
299 resp.json().await.unwrap()
300 }
301
302 pub async fn refresh_token(&self, client: &Client, refresh_token: &str) -> Value {
303 let resp = self
304 .http
305 .post(self.base_url.join("token").unwrap())
306 .basic_auth(&client.client_id, client.client_secret.as_ref())
307 .form(&[
308 ("grant_type", "refresh_token"),
309 ("refresh_token", refresh_token),
310 ])
311 .send()
312 .await
313 .unwrap();
314
315 resp.json().await.unwrap()
316 }
317
318 pub async fn introspect_token(&self, client: &Client, token: &str) -> Value {
319 let resp = self
320 .http
321 .post(self.base_url.join("introspect").unwrap())
322 .basic_auth(&client.client_id, client.client_secret.as_ref())
323 .form(&[("token", token)])
324 .send()
325 .await
326 .unwrap();
327
328 resp.json().await.unwrap()
329 }
330
331 pub async fn revoke_token(&self, client: &Client, token: &str) {
332 let resp = self
333 .http
334 .post(self.base_url.join("revoke").unwrap())
335 .basic_auth(&client.client_id, client.client_secret.as_ref())
336 .form(&[("token", token)])
337 .send()
338 .await
339 .unwrap();
340
341 assert!(resp.status().is_success());
342 }
343
344 pub fn client_assertion_jwt(&self, client: &Client) -> String {
345 let claims = json!({
346 "iss": client.client_id,
347 "sub": client.client_id,
348 "aud": self.issuer(),
349 "exp": (chrono::Utc::now() + chrono::Duration::minutes(5)).timestamp(),
350 "iat": chrono::Utc::now().timestamp(),
351 "jti": Uuid::new_v4().to_string(),
352 });
353
354 let mut header = Header::new(Algorithm::RS256);
355 header.kid = Some(KID.to_string());
356
357 encode(&header, &claims, &KEYS.encoding).unwrap()
358 }
359
360 pub fn base_url(&self) -> &url::Url {
361 &self.base_url
362 }
363
364 pub fn issuer(&self) -> &str {
365 self.state.issuer()
366 }
367}
368
/// Options for minting a JWT or token via `OAuthTestServer::generate_jwt` /
/// `generate_token`.
///
/// Note that the derived `Default` yields an empty `user_id` and
/// `expires_in == 0`. Use [`JwtOptionsBuilder`] to get the test-friendly
/// defaults instead.
#[derive(Debug, Default)]
pub struct JwtOptions {
    /// User the token is issued for (presumably the `sub` claim — confirm in `AppState`).
    pub user_id: String,
    /// Scope to embed; `None` leaves the choice to the token generator.
    pub scope: Option<String>,
    /// Token lifetime in seconds.
    pub expires_in: i64,
}

/// Fluent builder for [`JwtOptions`].
///
/// Fields left unset fall back to [`JwtOptionsBuilder::DEFAULT_USER_ID`] and
/// [`JwtOptionsBuilder::DEFAULT_EXPIRES_IN`]; `scope` stays `None`.
#[derive(Debug, Clone, Default)]
pub struct JwtOptionsBuilder {
    user_id: Option<String>,
    scope: Option<String>,
    expires_in: Option<i64>,
}

impl JwtOptionsBuilder {
    /// User id used when [`JwtOptionsBuilder::user_id`] is not called.
    pub const DEFAULT_USER_ID: &'static str = "test-user-123";
    /// Lifetime in seconds used when [`JwtOptionsBuilder::expires_in`] is not called.
    pub const DEFAULT_EXPIRES_IN: i64 = 3600;

    /// Sets the user the token is issued for.
    pub fn user_id(mut self, user_id: impl Into<String>) -> Self {
        self.user_id = Some(user_id.into());
        self
    }

    /// Sets the scope to embed.
    pub fn scope(mut self, scope: impl Into<String>) -> Self {
        self.scope = Some(scope.into());
        self
    }

    /// Sets the token lifetime in seconds.
    pub fn expires_in(mut self, seconds: i64) -> Self {
        self.expires_in = Some(seconds);
        self
    }

    /// Finalizes the options, filling unset fields with their defaults.
    pub fn build(self) -> JwtOptions {
        JwtOptions {
            // Lazily allocate the default only when no user id was supplied.
            user_id: self
                .user_id
                .unwrap_or_else(|| Self::DEFAULT_USER_ID.to_owned()),
            scope: self.scope,
            expires_in: self.expires_in.unwrap_or(Self::DEFAULT_EXPIRES_IN),
        }
    }
}
407
408use crate::server::{AppState, AuthorizationCode, Client, IssuerConfig, Token, KEYS, KID};