1use axum::{extract::State, http::StatusCode, response::Json};
7use chrono::{Duration, Utc};
8use jsonwebtoken::{Algorithm, EncodingKey, Header};
9use serde::{Deserialize, Serialize};
10use serde_json::json;
11use std::collections::HashMap;
12use std::sync::Arc;
13use tokio::sync::RwLock;
14
15use crate::auth::state::AuthState;
16use mockforge_core::Error;
17
18#[derive(Debug, Clone, Serialize, Deserialize)]
20pub struct OidcConfig {
21 pub enabled: bool,
23 pub issuer: String,
25 pub jwks: JwksConfig,
27 pub claims: ClaimsConfig,
29 pub multi_tenant: Option<MultiTenantConfig>,
31}
32
33#[derive(Debug, Clone, Serialize, Deserialize)]
35pub struct JwksConfig {
36 pub keys: Vec<JwkKey>,
38}
39
40#[derive(Debug, Clone, Serialize, Deserialize)]
42pub struct JwkKey {
43 pub kid: String,
45 pub alg: String,
47 pub public_key: String,
49 #[serde(skip_serializing)]
51 pub private_key: Option<String>,
52 pub kty: String,
54 #[serde(default = "default_key_use")]
56 pub use_: String,
57}
58
/// Serde default for [`JwkKey::use_`]: keys are signing keys unless stated otherwise.
fn default_key_use() -> String {
    String::from("sig")
}
62
63#[derive(Debug, Clone, Serialize, Deserialize)]
65pub struct ClaimsConfig {
66 pub default: Vec<String>,
68 #[serde(default)]
70 pub custom: HashMap<String, serde_json::Value>,
71}
72
73#[derive(Debug, Clone, Serialize, Deserialize)]
75pub struct MultiTenantConfig {
76 pub enabled: bool,
78 pub org_id_claim: String,
80 pub tenant_id_claim: Option<String>,
82}
83
84impl Default for OidcConfig {
85 fn default() -> Self {
86 Self {
87 enabled: false,
88 issuer: "https://mockforge.example.com".to_string(),
89 jwks: JwksConfig { keys: vec![] },
90 claims: ClaimsConfig {
91 default: vec!["sub".to_string(), "iss".to_string(), "exp".to_string()],
92 custom: HashMap::new(),
93 },
94 multi_tenant: None,
95 }
96 }
97}
98
99#[derive(Debug, Serialize)]
101pub struct OidcDiscoveryDocument {
102 pub issuer: String,
104 pub authorization_endpoint: String,
106 pub token_endpoint: String,
108 pub userinfo_endpoint: String,
110 pub jwks_uri: String,
112 pub response_types_supported: Vec<String>,
114 pub subject_types_supported: Vec<String>,
116 pub id_token_signing_alg_values_supported: Vec<String>,
118 pub scopes_supported: Vec<String>,
120 pub claims_supported: Vec<String>,
122 pub grant_types_supported: Vec<String>,
124}
125
126#[derive(Debug, Serialize)]
128pub struct JwksResponse {
129 pub keys: Vec<JwkPublicKey>,
131}
132
133#[derive(Debug, Serialize)]
135pub struct JwkPublicKey {
136 pub kid: String,
138 pub kty: String,
140 pub alg: String,
142 #[serde(rename = "use")]
144 pub use_: String,
145 #[serde(skip_serializing_if = "Option::is_none")]
147 pub n: Option<String>,
148 #[serde(skip_serializing_if = "Option::is_none")]
150 pub e: Option<String>,
151 #[serde(skip_serializing_if = "Option::is_none")]
153 pub crv: Option<String>,
154 #[serde(skip_serializing_if = "Option::is_none")]
156 pub x: Option<String>,
157 #[serde(skip_serializing_if = "Option::is_none")]
159 pub y: Option<String>,
160}
161
162#[derive(Clone)]
164pub struct OidcState {
165 pub config: OidcConfig,
167 pub signing_keys: Arc<RwLock<HashMap<String, EncodingKey>>>,
169}
170
171impl OidcState {
172 pub fn new(config: OidcConfig) -> Result<Self, Error> {
174 let mut signing_keys = HashMap::new();
175
176 for key in &config.jwks.keys {
178 if let Some(ref private_key) = key.private_key {
179 let encoding_key = match key.alg.as_str() {
180 "RS256" | "RS384" | "RS512" => {
181 EncodingKey::from_rsa_pem(private_key.as_bytes()).map_err(|e| {
182 Error::generic(format!("Failed to load RSA key {}: {}", key.kid, e))
183 })?
184 }
185 "ES256" | "ES384" | "ES512" => EncodingKey::from_ec_pem(private_key.as_bytes())
186 .map_err(|e| {
187 Error::generic(format!("Failed to load EC key {}: {}", key.kid, e))
188 })?,
189 "HS256" | "HS384" | "HS512" => EncodingKey::from_secret(private_key.as_bytes()),
190 _ => {
191 return Err(Error::generic(format!("Unsupported algorithm: {}", key.alg)));
192 }
193 };
194 signing_keys.insert(key.kid.clone(), encoding_key);
195 }
196 }
197
198 Ok(Self {
199 config,
200 signing_keys: Arc::new(RwLock::new(signing_keys)),
201 })
202 }
203
204 pub fn default_mock() -> Result<Self, Error> {
209 use std::env;
210
211 let issuer = env::var("MOCKFORGE_OIDC_ISSUER").unwrap_or_else(|_| {
213 env::var("MOCKFORGE_BASE_URL")
214 .unwrap_or_else(|_| "https://mockforge.example.com".to_string())
215 });
216
217 let default_secret = env::var("MOCKFORGE_OIDC_SECRET")
219 .unwrap_or_else(|_| "mockforge-default-secret-key-change-in-production".to_string());
220
221 let default_key = JwkKey {
222 kid: "default".to_string(),
223 alg: "HS256".to_string(),
224 public_key: default_secret.clone(),
225 private_key: Some(default_secret),
226 kty: "oct".to_string(),
227 use_: "sig".to_string(),
228 };
229
230 let config = OidcConfig {
231 enabled: true,
232 issuer,
233 jwks: JwksConfig {
234 keys: vec![default_key],
235 },
236 claims: ClaimsConfig {
237 default: vec!["sub".to_string(), "iss".to_string(), "exp".to_string()],
238 custom: HashMap::new(),
239 },
240 multi_tenant: None,
241 };
242
243 Self::new(config)
244 }
245}
246
247pub fn load_oidc_state() -> Option<OidcState> {
256 use std::env;
257
258 if let Ok(disabled) = env::var("MOCKFORGE_OIDC_ENABLED") {
260 if disabled == "false" || disabled == "0" {
261 return None;
262 }
263 }
264
265 if let Ok(config_json) = env::var("MOCKFORGE_OIDC_CONFIG") {
267 if let Ok(config) = serde_json::from_str::<OidcConfig>(&config_json) {
268 if config.enabled {
269 return OidcState::new(config).ok();
270 }
271 return None;
272 }
273 }
274
275 OidcState::default_mock().ok()
278}
279
280pub async fn get_oidc_discovery() -> Json<OidcDiscoveryDocument> {
282 let base_url = std::env::var("MOCKFORGE_BASE_URL")
285 .unwrap_or_else(|_| "https://mockforge.example.com".to_string());
286
287 let discovery = OidcDiscoveryDocument {
288 issuer: base_url.clone(),
289 authorization_endpoint: format!("{}/oauth2/authorize", base_url),
290 token_endpoint: format!("{}/oauth2/token", base_url),
291 userinfo_endpoint: format!("{}/oauth2/userinfo", base_url),
292 jwks_uri: format!("{}/.well-known/jwks.json", base_url),
293 response_types_supported: vec![
294 "code".to_string(),
295 "id_token".to_string(),
296 "token id_token".to_string(),
297 ],
298 subject_types_supported: vec!["public".to_string()],
299 id_token_signing_alg_values_supported: vec![
300 "RS256".to_string(),
301 "ES256".to_string(),
302 "HS256".to_string(),
303 ],
304 scopes_supported: vec![
305 "openid".to_string(),
306 "profile".to_string(),
307 "email".to_string(),
308 "address".to_string(),
309 "phone".to_string(),
310 ],
311 claims_supported: vec![
312 "sub".to_string(),
313 "iss".to_string(),
314 "aud".to_string(),
315 "exp".to_string(),
316 "iat".to_string(),
317 "auth_time".to_string(),
318 "nonce".to_string(),
319 "email".to_string(),
320 "email_verified".to_string(),
321 "name".to_string(),
322 "given_name".to_string(),
323 "family_name".to_string(),
324 ],
325 grant_types_supported: vec![
326 "authorization_code".to_string(),
327 "implicit".to_string(),
328 "refresh_token".to_string(),
329 "client_credentials".to_string(),
330 ],
331 };
332
333 Json(discovery)
334}
335
336pub async fn get_jwks() -> Json<JwksResponse> {
338 let jwks = JwksResponse { keys: vec![] };
341
342 Json(jwks)
343}
344
345pub fn get_jwks_from_state(oidc_state: &OidcState) -> Result<JwksResponse, Error> {
347 use crate::auth::jwks_converter::convert_jwk_key_simple;
348
349 let mut public_keys = Vec::new();
350
351 for key in &oidc_state.config.jwks.keys {
352 match convert_jwk_key_simple(key) {
353 Ok(jwk) => public_keys.push(jwk),
354 Err(e) => {
355 tracing::warn!("Failed to convert key {} to JWK format: {}", key.kid, e);
356 }
358 }
359 }
360
361 Ok(JwksResponse { keys: public_keys })
362}
363
364pub fn generate_signed_jwt(
375 mut claims: HashMap<String, serde_json::Value>,
376 kid: Option<String>,
377 algorithm: Algorithm,
378 encoding_key: &EncodingKey,
379 expires_in_seconds: Option<i64>,
380 issuer: Option<String>,
381 audience: Option<String>,
382) -> Result<String, Error> {
383 use chrono::Utc;
384
385 let mut header = Header::new(algorithm);
386 if let Some(kid) = kid {
387 header.kid = Some(kid);
388 }
389
390 let now = Utc::now();
392 claims.insert("iat".to_string(), json!(now.timestamp()));
393
394 if let Some(exp_seconds) = expires_in_seconds {
395 let exp = now + chrono::Duration::seconds(exp_seconds);
396 claims.insert("exp".to_string(), json!(exp.timestamp()));
397 }
398
399 if let Some(iss) = issuer {
400 claims.insert("iss".to_string(), json!(iss));
401 }
402
403 if let Some(aud) = audience {
404 claims.insert("aud".to_string(), json!(aud));
405 }
406
407 let token = jsonwebtoken::encode(&header, &claims, encoding_key)
408 .map_err(|e| Error::generic(format!("Failed to sign JWT: {}", e)))?;
409
410 Ok(token)
411}
412
/// Tenant identifiers to embed in a token when multi-tenancy is enabled.
///
/// Either field may be `None`; [`generate_oidc_token`] then substitutes
/// `"org-default"` / `"tenant-default"`. `Default` yields both fields `None`.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct TenantContext {
    /// Organisation id for the configured org claim.
    pub org_id: Option<String>,
    /// Tenant id for the configured tenant claim, if one is configured.
    pub tenant_id: Option<String>,
}
421
422pub fn generate_oidc_token(
424 oidc_state: &OidcState,
425 subject: String,
426 additional_claims: Option<HashMap<String, serde_json::Value>>,
427 expires_in_seconds: Option<i64>,
428 tenant_context: Option<TenantContext>,
429) -> Result<String, Error> {
430 use chrono::Utc;
431 use jsonwebtoken::Algorithm;
432
433 let mut claims = HashMap::new();
435 claims.insert("sub".to_string(), json!(subject));
436 claims.insert("iss".to_string(), json!(oidc_state.config.issuer.clone()));
437
438 for claim_name in &oidc_state.config.claims.default {
440 if !claims.contains_key(claim_name) {
441 match claim_name.as_str() {
443 "sub" | "iss" => {} "exp" => {
445 let exp_seconds = expires_in_seconds.unwrap_or(3600);
446 let exp = Utc::now() + chrono::Duration::seconds(exp_seconds);
447 claims.insert("exp".to_string(), json!(exp.timestamp()));
448 }
449 "iat" => {
450 claims.insert("iat".to_string(), json!(Utc::now().timestamp()));
451 }
452 _ => {
453 if let Some(value) = oidc_state.config.claims.custom.get(claim_name) {
455 claims.insert(claim_name.clone(), value.clone());
456 }
457 }
458 }
459 }
460 }
461
462 for (key, value) in &oidc_state.config.claims.custom {
464 if !claims.contains_key(key) {
465 claims.insert(key.clone(), value.clone());
466 }
467 }
468
469 if let Some(ref mt_config) = oidc_state.config.multi_tenant {
471 if mt_config.enabled {
472 let org_id = tenant_context
474 .as_ref()
475 .and_then(|ctx| ctx.org_id.clone())
476 .unwrap_or_else(|| "org-default".to_string());
477 let tenant_id = tenant_context
478 .as_ref()
479 .and_then(|ctx| ctx.tenant_id.clone())
480 .or_else(|| Some("tenant-default".to_string()));
481
482 claims.insert(mt_config.org_id_claim.clone(), json!(org_id));
483 if let Some(ref tenant_claim) = mt_config.tenant_id_claim {
484 if let Some(tid) = tenant_id {
485 claims.insert(tenant_claim.clone(), json!(tid));
486 }
487 }
488 }
489 }
490
491 if let Some(additional) = additional_claims {
493 for (key, value) in additional {
494 claims.insert(key, value);
495 }
496 }
497
498 let signing_keys = oidc_state.signing_keys.blocking_read();
500 let (kid, encoding_key) = signing_keys
501 .iter()
502 .next()
503 .ok_or_else(|| Error::generic("No signing keys available".to_string()))?;
504
505 let algorithm = oidc_state
508 .config
509 .jwks
510 .keys
511 .iter()
512 .find(|k| k.kid == *kid)
513 .and_then(|k| match k.alg.as_str() {
514 "RS256" => Some(Algorithm::RS256),
515 "RS384" => Some(Algorithm::RS384),
516 "RS512" => Some(Algorithm::RS512),
517 "ES256" => Some(Algorithm::ES256),
518 "ES384" => Some(Algorithm::ES384),
519 "HS256" => Some(Algorithm::HS256),
520 "HS384" => Some(Algorithm::HS384),
521 "HS512" => Some(Algorithm::HS512),
522 _ => None,
523 })
524 .unwrap_or(Algorithm::HS256);
525
526 generate_signed_jwt(
527 claims,
528 Some(kid.clone()),
529 algorithm,
530 encoding_key,
531 expires_in_seconds,
532 Some(oidc_state.config.issuer.clone()),
533 None,
534 )
535}
536
537pub fn oidc_router() -> axum::Router {
539 use axum::{routing::get, Router};
540
541 Router::new()
542 .route("/.well-known/openid-configuration", get(get_oidc_discovery))
543 .route("/.well-known/jwks.json", get(get_jwks))
544}