1use axum::response::Json;
7use jsonwebtoken::{Algorithm, EncodingKey, Header};
8use serde::{Deserialize, Serialize};
9use serde_json::json;
10use std::collections::HashMap;
11use std::sync::Arc;
12use tokio::sync::RwLock;
13
14use mockforge_core::Error;
15
/// OIDC settings for the mock identity provider.
///
/// Deserialized from the `MOCKFORGE_OIDC_CONFIG` environment variable by `load_oidc_state`,
/// or constructed by `OidcState::default_mock`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OidcConfig {
    /// Whether OIDC is active; `load_oidc_state` returns `None` for a config with `enabled: false`.
    pub enabled: bool,
    /// Issuer written into the `iss` claim of tokens produced by `generate_oidc_token`.
    pub issuer: String,
    /// Keys published through JWKS and, when they carry a private key, used for signing.
    pub jwks: JwksConfig,
    /// Claims added to every issued token.
    pub claims: ClaimsConfig,
    /// Optional multi-tenant claim injection; `None` means no tenant claims are added.
    pub multi_tenant: Option<MultiTenantConfig>,
}
30
/// The set of keys known to the provider.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct JwksConfig {
    /// Keys in configuration order; each may or may not include private key material.
    pub keys: Vec<JwkKey>,
}
37
38#[derive(Debug, Clone, Serialize, Deserialize)]
40pub struct JwkKey {
41 pub kid: String,
43 pub alg: String,
45 pub public_key: String,
47 #[serde(skip_serializing)]
49 pub private_key: Option<String>,
50 pub kty: String,
52 #[serde(default = "default_key_use")]
54 pub use_: String,
55}
56
/// Serde default for `JwkKey::use_`: keys are treated as signature keys unless stated otherwise.
fn default_key_use() -> String {
    String::from("sig")
}
60
/// Claim configuration applied when issuing tokens.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ClaimsConfig {
    /// Claim names to populate on every token. `exp` and `iat` are generated by
    /// `generate_oidc_token`, `sub`/`iss` are always set, and any other name is copied from
    /// `custom` when present there.
    pub default: Vec<String>,
    /// Static claim values added to every token without overriding claims already set.
    /// An absent field deserializes to an empty map.
    #[serde(default)]
    pub custom: HashMap<String, serde_json::Value>,
}
70
/// Multi-tenant claim settings.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MultiTenantConfig {
    /// When `false`, no tenant claims are added even though the section is present.
    pub enabled: bool,
    /// Name of the claim that receives the organization id (`org-default` when the caller
    /// supplies none).
    pub org_id_claim: String,
    /// Name of the claim that receives the tenant id; `None` means no tenant claim is emitted.
    pub tenant_id_claim: Option<String>,
}
81
82impl Default for OidcConfig {
83 fn default() -> Self {
84 Self {
85 enabled: false,
86 issuer: "https://mockforge.example.com".to_string(),
87 jwks: JwksConfig { keys: vec![] },
88 claims: ClaimsConfig {
89 default: vec!["sub".to_string(), "iss".to_string(), "exp".to_string()],
90 custom: HashMap::new(),
91 },
92 multi_tenant: None,
93 }
94 }
95}
96
/// OpenID Provider metadata served at `/.well-known/openid-configuration`.
///
/// Built by `get_oidc_discovery`; field names serialize unchanged and mirror the OIDC
/// Discovery metadata member names.
#[derive(Debug, Serialize)]
pub struct OidcDiscoveryDocument {
    /// Issuer identifier advertised to relying parties.
    pub issuer: String,
    /// URL of the authorization endpoint.
    pub authorization_endpoint: String,
    /// URL of the token endpoint.
    pub token_endpoint: String,
    /// URL of the userinfo endpoint.
    pub userinfo_endpoint: String,
    /// URL of the JWKS document.
    pub jwks_uri: String,
    /// Supported `response_type` values.
    pub response_types_supported: Vec<String>,
    /// Supported subject identifier types.
    pub subject_types_supported: Vec<String>,
    /// JWS algorithms ID tokens may be signed with.
    pub id_token_signing_alg_values_supported: Vec<String>,
    /// Supported OAuth 2.0 scopes.
    pub scopes_supported: Vec<String>,
    /// Claim names the provider may return.
    pub claims_supported: Vec<String>,
    /// Supported OAuth 2.0 grant types.
    pub grant_types_supported: Vec<String>,
}
123
/// JSON Web Key Set body for `/.well-known/jwks.json`.
#[derive(Debug, Serialize)]
pub struct JwksResponse {
    /// Public keys. `get_jwks` returns an empty list; `get_jwks_from_state` fills it from config.
    pub keys: Vec<JwkPublicKey>,
}
130
/// Public JWK entry as published in a [`JwksResponse`].
///
/// Key-type-specific members are `Option`s and are left out of the JSON when `None`.
#[derive(Debug, Serialize)]
pub struct JwkPublicKey {
    /// Key identifier (JWK `kid`).
    pub kid: String,
    /// Key type (JWK `kty`).
    pub kty: String,
    /// Algorithm the key is intended for (JWK `alg`).
    pub alg: String,
    /// Intended key use, serialized under the JWK member name `use`.
    #[serde(rename = "use")]
    pub use_: String,
    /// RSA modulus (JWK `n`).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub n: Option<String>,
    /// RSA public exponent (JWK `e`).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub e: Option<String>,
    /// Elliptic-curve name (JWK `crv`).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub crv: Option<String>,
    /// Elliptic-curve x coordinate (JWK `x`).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub x: Option<String>,
    /// Elliptic-curve y coordinate (JWK `y`).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub y: Option<String>,
}
159
/// Runtime OIDC state: the configuration plus the signing keys parsed from it.
///
/// `signing_keys` is shared between clones through an `Arc`; `config` is cloned by value.
#[derive(Clone)]
pub struct OidcState {
    /// Configuration this state was built from.
    pub config: OidcConfig,
    /// Encoding keys indexed by `kid`, built by `OidcState::new` from JWKS entries that carry a
    /// private key. Guarded by a tokio `RwLock`.
    pub signing_keys: Arc<RwLock<HashMap<String, EncodingKey>>>,
}
168
169impl OidcState {
170 pub fn new(config: OidcConfig) -> Result<Self, Error> {
172 let mut signing_keys = HashMap::new();
173
174 for key in &config.jwks.keys {
176 if let Some(ref private_key) = key.private_key {
177 let encoding_key = match key.alg.as_str() {
178 "RS256" | "RS384" | "RS512" => {
179 EncodingKey::from_rsa_pem(private_key.as_bytes()).map_err(|e| {
180 Error::generic(format!("Failed to load RSA key {}: {}", key.kid, e))
181 })?
182 }
183 "ES256" | "ES384" | "ES512" => EncodingKey::from_ec_pem(private_key.as_bytes())
184 .map_err(|e| {
185 Error::generic(format!("Failed to load EC key {}: {}", key.kid, e))
186 })?,
187 "HS256" | "HS384" | "HS512" => EncodingKey::from_secret(private_key.as_bytes()),
188 _ => {
189 return Err(Error::generic(format!("Unsupported algorithm: {}", key.alg)));
190 }
191 };
192 signing_keys.insert(key.kid.clone(), encoding_key);
193 }
194 }
195
196 Ok(Self {
197 config,
198 signing_keys: Arc::new(RwLock::new(signing_keys)),
199 })
200 }
201
202 pub fn default_mock() -> Result<Self, Error> {
207 use std::env;
208
209 let issuer = env::var("MOCKFORGE_OIDC_ISSUER").unwrap_or_else(|_| {
211 env::var("MOCKFORGE_BASE_URL")
212 .unwrap_or_else(|_| "https://mockforge.example.com".to_string())
213 });
214
215 let default_secret = env::var("MOCKFORGE_OIDC_SECRET")
217 .unwrap_or_else(|_| "mockforge-default-secret-key-change-in-production".to_string());
218
219 let default_key = JwkKey {
220 kid: "default".to_string(),
221 alg: "HS256".to_string(),
222 public_key: default_secret.clone(),
223 private_key: Some(default_secret),
224 kty: "oct".to_string(),
225 use_: "sig".to_string(),
226 };
227
228 let config = OidcConfig {
229 enabled: true,
230 issuer,
231 jwks: JwksConfig {
232 keys: vec![default_key],
233 },
234 claims: ClaimsConfig {
235 default: vec!["sub".to_string(), "iss".to_string(), "exp".to_string()],
236 custom: HashMap::new(),
237 },
238 multi_tenant: None,
239 };
240
241 Self::new(config)
242 }
243}
244
245pub fn load_oidc_state() -> Option<OidcState> {
254 use std::env;
255
256 if let Ok(disabled) = env::var("MOCKFORGE_OIDC_ENABLED") {
258 if disabled == "false" || disabled == "0" {
259 return None;
260 }
261 }
262
263 if let Ok(config_json) = env::var("MOCKFORGE_OIDC_CONFIG") {
265 if let Ok(config) = serde_json::from_str::<OidcConfig>(&config_json) {
266 if config.enabled {
267 return OidcState::new(config).ok();
268 }
269 return None;
270 }
271 }
272
273 OidcState::default_mock().ok()
276}
277
278pub async fn get_oidc_discovery() -> Json<OidcDiscoveryDocument> {
280 let base_url = std::env::var("MOCKFORGE_BASE_URL")
283 .unwrap_or_else(|_| "https://mockforge.example.com".to_string());
284
285 let discovery = OidcDiscoveryDocument {
286 issuer: base_url.clone(),
287 authorization_endpoint: format!("{}/oauth2/authorize", base_url),
288 token_endpoint: format!("{}/oauth2/token", base_url),
289 userinfo_endpoint: format!("{}/oauth2/userinfo", base_url),
290 jwks_uri: format!("{}/.well-known/jwks.json", base_url),
291 response_types_supported: vec![
292 "code".to_string(),
293 "id_token".to_string(),
294 "token id_token".to_string(),
295 ],
296 subject_types_supported: vec!["public".to_string()],
297 id_token_signing_alg_values_supported: vec![
298 "RS256".to_string(),
299 "ES256".to_string(),
300 "HS256".to_string(),
301 ],
302 scopes_supported: vec![
303 "openid".to_string(),
304 "profile".to_string(),
305 "email".to_string(),
306 "address".to_string(),
307 "phone".to_string(),
308 ],
309 claims_supported: vec![
310 "sub".to_string(),
311 "iss".to_string(),
312 "aud".to_string(),
313 "exp".to_string(),
314 "iat".to_string(),
315 "auth_time".to_string(),
316 "nonce".to_string(),
317 "email".to_string(),
318 "email_verified".to_string(),
319 "name".to_string(),
320 "given_name".to_string(),
321 "family_name".to_string(),
322 ],
323 grant_types_supported: vec![
324 "authorization_code".to_string(),
325 "implicit".to_string(),
326 "refresh_token".to_string(),
327 "client_credentials".to_string(),
328 ],
329 };
330
331 Json(discovery)
332}
333
334pub async fn get_jwks() -> Json<JwksResponse> {
336 let jwks = JwksResponse { keys: vec![] };
339
340 Json(jwks)
341}
342
343pub fn get_jwks_from_state(oidc_state: &OidcState) -> Result<JwksResponse, Error> {
345 use crate::auth::jwks_converter::convert_jwk_key_simple;
346
347 let mut public_keys = Vec::new();
348
349 for key in &oidc_state.config.jwks.keys {
350 match convert_jwk_key_simple(key) {
351 Ok(jwk) => public_keys.push(jwk),
352 Err(e) => {
353 tracing::warn!("Failed to convert key {} to JWK format: {}", key.kid, e);
354 }
356 }
357 }
358
359 Ok(JwksResponse { keys: public_keys })
360}
361
362pub fn generate_signed_jwt(
373 mut claims: HashMap<String, serde_json::Value>,
374 kid: Option<String>,
375 algorithm: Algorithm,
376 encoding_key: &EncodingKey,
377 expires_in_seconds: Option<i64>,
378 issuer: Option<String>,
379 audience: Option<String>,
380) -> Result<String, Error> {
381 use chrono::Utc;
382
383 let mut header = Header::new(algorithm);
384 if let Some(kid) = kid {
385 header.kid = Some(kid);
386 }
387
388 let now = Utc::now();
390 claims.insert("iat".to_string(), json!(now.timestamp()));
391
392 if let Some(exp_seconds) = expires_in_seconds {
393 let exp = now + chrono::Duration::seconds(exp_seconds);
394 claims.insert("exp".to_string(), json!(exp.timestamp()));
395 }
396
397 if let Some(iss) = issuer {
398 claims.insert("iss".to_string(), json!(iss));
399 }
400
401 if let Some(aud) = audience {
402 claims.insert("aud".to_string(), json!(aud));
403 }
404
405 let token = jsonwebtoken::encode(&header, &claims, encoding_key)
406 .map_err(|e| Error::generic(format!("Failed to sign JWT: {}", e)))?;
407
408 Ok(token)
409}
410
/// Per-token tenant identifiers passed to `generate_oidc_token`.
///
/// Only consulted when multi-tenancy is enabled in the config. Missing values fall back to
/// `org-default` / `tenant-default`.
#[derive(Debug, Clone)]
pub struct TenantContext {
    /// Organization id written to the configured org claim.
    pub org_id: Option<String>,
    /// Tenant id written to the configured tenant claim, if one is configured.
    pub tenant_id: Option<String>,
}
419
420pub fn generate_oidc_token(
422 oidc_state: &OidcState,
423 subject: String,
424 additional_claims: Option<HashMap<String, serde_json::Value>>,
425 expires_in_seconds: Option<i64>,
426 tenant_context: Option<TenantContext>,
427) -> Result<String, Error> {
428 use chrono::Utc;
429 use jsonwebtoken::Algorithm;
430
431 let mut claims = HashMap::new();
433 claims.insert("sub".to_string(), json!(subject));
434 claims.insert("iss".to_string(), json!(oidc_state.config.issuer.clone()));
435
436 for claim_name in &oidc_state.config.claims.default {
438 if !claims.contains_key(claim_name) {
439 match claim_name.as_str() {
441 "sub" | "iss" => {} "exp" => {
443 let exp_seconds = expires_in_seconds.unwrap_or(3600);
444 let exp = Utc::now() + chrono::Duration::seconds(exp_seconds);
445 claims.insert("exp".to_string(), json!(exp.timestamp()));
446 }
447 "iat" => {
448 claims.insert("iat".to_string(), json!(Utc::now().timestamp()));
449 }
450 _ => {
451 if let Some(value) = oidc_state.config.claims.custom.get(claim_name) {
453 claims.insert(claim_name.clone(), value.clone());
454 }
455 }
456 }
457 }
458 }
459
460 for (key, value) in &oidc_state.config.claims.custom {
462 if !claims.contains_key(key) {
463 claims.insert(key.clone(), value.clone());
464 }
465 }
466
467 if let Some(ref mt_config) = oidc_state.config.multi_tenant {
469 if mt_config.enabled {
470 let org_id = tenant_context
472 .as_ref()
473 .and_then(|ctx| ctx.org_id.clone())
474 .unwrap_or_else(|| "org-default".to_string());
475 let tenant_id = tenant_context
476 .as_ref()
477 .and_then(|ctx| ctx.tenant_id.clone())
478 .or_else(|| Some("tenant-default".to_string()));
479
480 claims.insert(mt_config.org_id_claim.clone(), json!(org_id));
481 if let Some(ref tenant_claim) = mt_config.tenant_id_claim {
482 if let Some(tid) = tenant_id {
483 claims.insert(tenant_claim.clone(), json!(tid));
484 }
485 }
486 }
487 }
488
489 if let Some(additional) = additional_claims {
491 for (key, value) in additional {
492 claims.insert(key, value);
493 }
494 }
495
496 let signing_keys = oidc_state.signing_keys.blocking_read();
498 let (kid, encoding_key) = signing_keys
499 .iter()
500 .next()
501 .ok_or_else(|| Error::generic("No signing keys available".to_string()))?;
502
503 let algorithm = oidc_state
506 .config
507 .jwks
508 .keys
509 .iter()
510 .find(|k| k.kid == *kid)
511 .and_then(|k| match k.alg.as_str() {
512 "RS256" => Some(Algorithm::RS256),
513 "RS384" => Some(Algorithm::RS384),
514 "RS512" => Some(Algorithm::RS512),
515 "ES256" => Some(Algorithm::ES256),
516 "ES384" => Some(Algorithm::ES384),
517 "HS256" => Some(Algorithm::HS256),
518 "HS384" => Some(Algorithm::HS384),
519 "HS512" => Some(Algorithm::HS512),
520 _ => None,
521 })
522 .unwrap_or(Algorithm::HS256);
523
524 generate_signed_jwt(
525 claims,
526 Some(kid.clone()),
527 algorithm,
528 encoding_key,
529 expires_in_seconds,
530 Some(oidc_state.config.issuer.clone()),
531 None,
532 )
533}
534
535pub fn oidc_router() -> axum::Router {
537 use axum::{routing::get, Router};
538
539 Router::new()
540 .route("/.well-known/openid-configuration", get(get_oidc_discovery))
541 .route("/.well-known/jwks.json", get(get_jwks))
542}