1use base64::{engine::general_purpose, Engine};
2use jsonwebtoken::{encode, Algorithm, Header};
3use reqwest::{Client as HttpClient, StatusCode};
4use serde_json::{json, Value};
5use sha2::{Digest, Sha256};
6use std::{
7 collections::HashMap,
8 sync::{Arc, RwLock},
9};
10use tokio::task::JoinError;
11use url::Url;
12use uuid::Uuid;
13
/// A PKCE verifier together with the challenge derived from it.
#[derive(Clone, Debug)]
pub struct PkcePair {
    /// Secret presented to the token endpoint as the `code_verifier` form field.
    pub code_verifier: String,
    /// Value sent to the authorization endpoint as the `code_challenge` query parameter.
    pub code_challenge: String,
}
19
20#[derive(Debug, Default)]
21pub struct AuthorizeParams {
22 pub response_type: &'static str,
23 pub redirect_uri: String,
24 pub scope: String,
25 pub state: String,
26 pub pkce: Option<PkcePair>,
27 pub nonce: Option<String>,
28}
29
30impl AuthorizeParams {
31 pub fn new() -> Self {
32 Self {
33 response_type: "code",
34 redirect_uri: "http://localhost/cb".to_string(),
35 scope: "openid".to_string(),
36 state: Uuid::new_v4().to_string(),
37 pkce: None,
38 nonce: None,
39 }
40 }
41
42 pub fn redirect_uri(mut self, uri: impl Into<String>) -> Self {
43 self.redirect_uri = uri.into();
44 self
45 }
46
47 pub fn scope(mut self, scope: impl Into<String>) -> Self {
48 self.scope = scope.into();
49 self
50 }
51
52 pub fn state(mut self, state: impl Into<String>) -> Self {
53 self.state = state.into();
54 self
55 }
56
57 pub fn pkce(mut self, pkce: PkcePair) -> Self {
58 self.pkce = Some(pkce);
59 self
60 }
61
62 pub fn nonce(mut self, nonce: impl Into<String>) -> Self {
63 self.nonce = Some(nonce.into());
64 self
65 }
66}
67
/// Absolute URLs for each endpoint exposed by an `OAuthTestServer`.
#[derive(Clone, Debug)]
pub struct OauthEndpoints {
    /// Root URL of the server.
    pub oauth_server: String,
    /// OpenID Connect discovery document (`.well-known/openid-configuration`).
    pub discovery: String,
    /// Authorization endpoint.
    pub authorize: String,
    /// Token endpoint.
    pub token: String,
    /// Client registration endpoint.
    ///
    /// NOTE: the field name is misspelled; it is part of the public API, so it is kept
    /// as `regsiter` for compatibility with existing callers.
    pub regsiter: String,
    /// Token introspection endpoint.
    pub introspect: String,
    /// Token revocation endpoint.
    pub revoke: String,
    /// UserInfo endpoint.
    pub userinfo: String,
    /// JSON Web Key Set document (`.well-known/jwks.json`).
    pub jwks: String,
}
80
81pub struct OAuthTestServer {
103 state: AppState,
104 pub base_url: url::Url,
105 pub endpoints: OauthEndpoints,
106 http: HttpClient,
107 _handle: tokio::task::JoinHandle<()>,
108}
109
110impl OAuthTestServer {
111 pub async fn start() -> Self {
112 let config = IssuerConfig {
113 port: 0,
114 ..Default::default()
115 };
116 Self::start_with_config(config).await
117 }
118
119 pub fn clients(&self) -> Arc<RwLock<HashMap<String, Client>>> {
120 self.state.clients.clone()
121 }
122
123 pub fn codes(&self) -> Arc<RwLock<HashMap<String, AuthorizationCode>>> {
124 self.state.codes.clone()
125 }
126
127 pub fn tokens(&self) -> Arc<RwLock<HashMap<String, Token>>> {
128 self.state.tokens.clone()
129 }
130
131 pub fn refresh_tokens(&self) -> Arc<RwLock<HashMap<String, Token>>> {
132 self.state.refresh_tokens.clone()
133 }
134
135 pub async fn start_with_config(config: IssuerConfig) -> Self {
136 let state = AppState::new(config.clone());
138 let (addr, handle) = state.clone().start().await;
139 let base_url: Url = format!("http://{addr}").parse().unwrap();
140
141 let endpoints: OauthEndpoints = OauthEndpoints {
142 oauth_server: base_url.clone().to_string(),
143 discovery: format!("{base_url}.well-known/openid-configuration"),
144 authorize: format!("{base_url}register"),
145 regsiter: format!("{base_url}authorize"),
146 token: format!("{base_url}token"),
147 introspect: format!("{base_url}introspect"),
148 revoke: format!("{base_url}revoke"),
149 userinfo: format!("{base_url}userinfo"),
150 jwks: format!("{base_url}.well-known/jwks.json"),
151 };
152
153 Self {
154 state,
155 base_url,
156 endpoints,
157 http: HttpClient::new(),
158 _handle: handle,
159 }
160 }
161
162 pub async fn wait_for_shutdown(self) -> Result<(), JoinError> {
163 self._handle.await
164 }
165
166 pub fn register_client(&self, metadata: serde_json::Value) -> Client {
167 self.state
168 .register_client(metadata)
169 .expect("client registration failed")
170 }
171
172 pub fn register_client_with_secret(&self, metadata: Value, force_secret: bool) -> Client {
173 let mut meta = metadata;
174 if let Some(obj) = meta.as_object_mut() {
175 obj.insert(
176 "generate_client_secret_for_dcr".to_string(),
177 json!(force_secret),
178 );
179 }
180 self.register_client(meta)
181 }
182
183 pub fn generate_jwt(&self, client: &Client, options: JwtOptions) -> String {
184 self.state
185 .generate_jwt(client, options)
186 .expect("JWT generation failed")
187 }
188
189 pub fn jwt_options(&self) -> JwtOptionsBuilder {
190 JwtOptionsBuilder::default()
191 }
192
193 pub fn pkce_pair(&self) -> PkcePair {
194 use rand::Rng;
195 let verifier_bytes: [u8; 32] = rand::thread_rng().r#gen();
196 let code_verifier = general_purpose::URL_SAFE_NO_PAD.encode(verifier_bytes);
197 let challenge =
198 general_purpose::URL_SAFE_NO_PAD.encode(Sha256::digest(code_verifier.as_bytes()));
199 PkcePair {
200 code_verifier,
201 code_challenge: challenge,
202 }
203 }
204
205 pub fn authorize_url(&self, client: &Client, params: AuthorizeParams) -> Url {
206 let mut url = self.base_url.join("authorize").unwrap();
207 let mut query = url.query_pairs_mut();
208
209 query
210 .append_pair("response_type", params.response_type)
211 .append_pair("client_id", &client.client_id)
212 .append_pair("redirect_uri", ¶ms.redirect_uri)
213 .append_pair("scope", ¶ms.scope)
214 .append_pair("state", ¶ms.state);
215
216 if let Some(pkce) = params.pkce {
217 query
218 .append_pair("code_challenge", &pkce.code_challenge)
219 .append_pair("code_challenge_method", "S256");
220 }
221
222 if let Some(nonce) = params.nonce {
223 query.append_pair("nonce", &nonce);
224 }
225
226 drop(query);
227 url
228 }
229
230 pub fn rotate_keys(&self) {
231 unimplemented!("Key rotation not implemented in test server")
233 }
234
235 pub async fn approve_consent(&self, auth_url: &Url, user_id: &str) -> String {
236 let resp = self.http.get(auth_url.clone()).send().await.unwrap();
237 assert_eq!(resp.status(), StatusCode::SEE_OTHER);
238
239 let location = resp.headers().get("location").unwrap().to_str().unwrap();
240 let redirect = Url::parse(location).unwrap();
241 let code = redirect
242 .query_pairs()
243 .find(|(k, _)| k == "code")
244 .map(|(_, v)| v.to_string())
245 .expect("no code in redirect");
246
247 let code_obj = self
249 .state
250 .codes
251 .read()
252 .unwrap()
253 .get(&code)
254 .cloned()
255 .unwrap();
256 let mut code_obj = code_obj;
257 code_obj.user_id = user_id.to_string();
258 self.state
259 .codes
260 .write()
261 .unwrap()
262 .insert(code.clone(), code_obj);
263
264 code
265 }
266
267 pub async fn exchange_code(
268 &self,
269 client: &Client,
270 code: &str,
271 pkce: Option<&PkcePair>,
272 ) -> Value {
273 let mut form = vec![
274 ("grant_type", "authorization_code"),
275 ("code", code),
276 ("redirect_uri", "http://localhost/cb"),
277 ];
278
279 if let Some(pkce) = pkce {
280 form.push(("code_verifier", &pkce.code_verifier));
281 }
282
283 let resp = self
284 .http
285 .post(self.base_url.join("token").unwrap())
286 .basic_auth(&client.client_id, client.client_secret.as_ref())
287 .form(&form)
288 .send()
289 .await
290 .unwrap();
291
292 assert_eq!(resp.status(), StatusCode::OK);
293 resp.json().await.unwrap()
294 }
295
296 pub async fn refresh_token(&self, client: &Client, refresh_token: &str) -> Value {
297 let resp = self
298 .http
299 .post(self.base_url.join("token").unwrap())
300 .basic_auth(&client.client_id, client.client_secret.as_ref())
301 .form(&[
302 ("grant_type", "refresh_token"),
303 ("refresh_token", refresh_token),
304 ])
305 .send()
306 .await
307 .unwrap();
308
309 resp.json().await.unwrap()
310 }
311
312 pub async fn introspect_token(&self, client: &Client, token: &str) -> Value {
313 let resp = self
314 .http
315 .post(self.base_url.join("introspect").unwrap())
316 .basic_auth(&client.client_id, client.client_secret.as_ref())
317 .form(&[("token", token)])
318 .send()
319 .await
320 .unwrap();
321
322 resp.json().await.unwrap()
323 }
324
325 pub async fn revoke_token(&self, client: &Client, token: &str) {
326 let resp = self
327 .http
328 .post(self.base_url.join("revoke").unwrap())
329 .basic_auth(&client.client_id, client.client_secret.as_ref())
330 .form(&[("token", token)])
331 .send()
332 .await
333 .unwrap();
334
335 assert!(resp.status().is_success());
336 }
337
338 pub fn client_assertion_jwt(&self, client: &Client) -> String {
339 let claims = json!({
340 "iss": client.client_id,
341 "sub": client.client_id,
342 "aud": self.issuer(),
343 "exp": (chrono::Utc::now() + chrono::Duration::minutes(5)).timestamp(),
344 "iat": chrono::Utc::now().timestamp(),
345 "jti": Uuid::new_v4().to_string(),
346 });
347
348 let mut header = Header::new(Algorithm::RS256);
349 header.kid = Some(KID.to_string());
350
351 encode(&header, &claims, &KEYS.encoding).unwrap()
352 }
353
354 pub fn base_url(&self) -> &url::Url {
355 &self.base_url
356 }
357
358 pub fn issuer(&self) -> &str {
359 self.state.issuer()
360 }
361}
362
/// Parameters for a JWT minted via `OAuthTestServer::generate_jwt`.
///
/// Usually produced by [`JwtOptionsBuilder::build`], which fills in test-friendly
/// defaults. `JwtOptions::default()` instead yields an empty `user_id`, no scope and an
/// `expires_in` of `0`.
#[derive(Debug, Default)]
pub struct JwtOptions {
    /// User identifier handed to the server's JWT generator.
    pub user_id: String,
    /// Optional scope string; `None` unless set on the builder.
    pub scope: Option<String>,
    /// Token lifetime in seconds.
    pub expires_in: i64,
}

/// Fluent builder for [`JwtOptions`]. Fields left unset fall back to defaults in
/// [`JwtOptionsBuilder::build`].
#[derive(Debug, Default)]
pub struct JwtOptionsBuilder {
    user_id: Option<String>,
    scope: Option<String>,
    expires_in: Option<i64>,
}

impl JwtOptionsBuilder {
    /// `user_id` used when none was supplied.
    const DEFAULT_USER_ID: &'static str = "test-user-123";
    /// Lifetime (seconds) used when none was supplied.
    const DEFAULT_EXPIRES_IN: i64 = 3600;

    /// Sets the user identifier.
    #[must_use]
    pub fn user_id(mut self, user_id: impl Into<String>) -> Self {
        self.user_id = Some(user_id.into());
        self
    }

    /// Sets the scope string.
    #[must_use]
    pub fn scope(mut self, scope: impl Into<String>) -> Self {
        self.scope = Some(scope.into());
        self
    }

    /// Sets the token lifetime in seconds.
    #[must_use]
    pub fn expires_in(mut self, seconds: i64) -> Self {
        self.expires_in = Some(seconds);
        self
    }

    /// Finalises the options. An unset `user_id` becomes `"test-user-123"`, an unset
    /// `expires_in` becomes `3600`, and `scope` stays `None` when unset.
    #[must_use]
    pub fn build(self) -> JwtOptions {
        JwtOptions {
            // `unwrap_or_else` only allocates the default when no user id was supplied.
            user_id: self
                .user_id
                .unwrap_or_else(|| Self::DEFAULT_USER_ID.to_owned()),
            scope: self.scope,
            expires_in: self.expires_in.unwrap_or(Self::DEFAULT_EXPIRES_IN),
        }
    }
}
401
402use crate::server::{AppState, AuthorizationCode, Client, IssuerConfig, Token, KEYS, KID};