1use axum::{
7 extract::State,
8 http::StatusCode,
9 response::Json,
10};
11use chrono::Utc;
12use jsonwebtoken::{Algorithm, EncodingKey, Header};
13use serde::{Deserialize, Serialize};
14use serde_json::json;
15use std::collections::HashMap;
16use std::sync::Arc;
17use tokio::sync::RwLock;
18
19use crate::auth::state::AuthState;
20use mockforge_core::Error;
21
/// Configuration for the mock OpenID Connect provider.
///
/// Consumed by [`OidcState::new`] (which loads signing keys from `jwks`) and by
/// `generate_oidc_token` (which reads `issuer`, `claims` and `multi_tenant`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OidcConfig {
    /// Feature flag for OIDC support.
    /// NOTE(review): not read anywhere in this section — confirm where it is checked.
    pub enabled: bool,
    /// Issuer identifier; written as the `iss` claim of tokens from `generate_oidc_token`.
    pub issuer: String,
    /// Key material; every key that carries a `private_key` becomes a signing key.
    pub jwks: JwksConfig,
    /// Default claim names and custom claim values added to generated tokens.
    pub claims: ClaimsConfig,
    /// Optional multi-tenant claim settings; `None` means no tenant claims are added.
    pub multi_tenant: Option<MultiTenantConfig>,
}
36
/// The set of JSON Web Keys configured for the provider.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct JwksConfig {
    /// Configured keys. `OidcState::new` builds an encoding key for each entry that
    /// has a `private_key`; `get_jwks_from_state` converts each entry to public JWK form.
    pub keys: Vec<JwkKey>,
}
43
44#[derive(Debug, Clone, Serialize, Deserialize)]
46pub struct JwkKey {
47 pub kid: String,
49 pub alg: String,
51 pub public_key: String,
53 #[serde(skip_serializing)]
55 pub private_key: Option<String>,
56 pub kty: String,
58 #[serde(default = "default_key_use")]
60 pub use_: String,
61}
62
/// Serde default for `JwkKey::use_`: keys are treated as signing keys (`"sig"`)
/// unless the configuration says otherwise.
fn default_key_use() -> String {
    String::from("sig")
}
66
/// Controls which claims `generate_oidc_token` adds to tokens.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ClaimsConfig {
    /// Claim names to include. `sub` and `iss` are always set; `exp` and `iat` are
    /// computed at issue time; any other name is filled from `custom` when present there.
    pub default: Vec<String>,
    /// Custom claim values keyed by claim name. Every entry is added to tokens unless
    /// a claim with that name was already set. Empty when absent from the configuration.
    #[serde(default)]
    pub custom: HashMap<String, serde_json::Value>,
}
76
/// Multi-tenant claim settings applied by `generate_oidc_token`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MultiTenantConfig {
    /// When `false`, no tenant-related claims are added.
    pub enabled: bool,
    /// Claim name under which the organization id is written
    /// (`"org-default"` is used when the caller supplies none).
    pub org_id_claim: String,
    /// Claim name for the tenant id; when `None`, no tenant claim is written.
    pub tenant_id_claim: Option<String>,
}
87
88impl Default for OidcConfig {
89 fn default() -> Self {
90 Self {
91 enabled: false,
92 issuer: "https://mockforge.example.com".to_string(),
93 jwks: JwksConfig { keys: vec![] },
94 claims: ClaimsConfig {
95 default: vec!["sub".to_string(), "iss".to_string(), "exp".to_string()],
96 custom: HashMap::new(),
97 },
98 multi_tenant: None,
99 }
100 }
101}
102
/// OpenID Connect discovery metadata, returned as JSON by `get_oidc_discovery`.
///
/// NOTE(review): `oidc_router` only registers the discovery and JWKS routes; the other
/// advertised endpoints must be served elsewhere — confirm.
#[derive(Debug, Serialize)]
pub struct OidcDiscoveryDocument {
    /// Issuer identifier advertised to clients.
    pub issuer: String,
    /// URL of the authorization endpoint.
    pub authorization_endpoint: String,
    /// URL of the token endpoint.
    pub token_endpoint: String,
    /// URL of the userinfo endpoint.
    pub userinfo_endpoint: String,
    /// URL of the JWK Set document.
    pub jwks_uri: String,
    /// Advertised `response_type` values.
    pub response_types_supported: Vec<String>,
    /// Advertised subject identifier types.
    pub subject_types_supported: Vec<String>,
    /// Advertised ID-token signing algorithms.
    pub id_token_signing_alg_values_supported: Vec<String>,
    /// Advertised scopes.
    pub scopes_supported: Vec<String>,
    /// Advertised claim names.
    pub claims_supported: Vec<String>,
    /// Advertised OAuth 2.0 grant types.
    pub grant_types_supported: Vec<String>,
}
129
/// JSON Web Key Set response body, serialized as `{"keys": [...]}`.
#[derive(Debug, Serialize)]
pub struct JwksResponse {
    /// Public keys in JWK form.
    pub keys: Vec<JwkPublicKey>,
}
136
/// Public half of a key in JSON Web Key form, as published in a [`JwksResponse`].
///
/// Every `Option` member is omitted from the JSON output when `None`.
#[derive(Debug, Serialize)]
pub struct JwkPublicKey {
    /// Key identifier (`kid`).
    pub kid: String,
    /// Key type (`kty`).
    pub kty: String,
    /// Algorithm (`alg`).
    pub alg: String,
    /// Intended key use, serialized under the JWK member name `use`.
    #[serde(rename = "use")]
    pub use_: String,
    /// JWK member `n` (RSA modulus); presumably set only for RSA keys by the converter.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub n: Option<String>,
    /// JWK member `e` (RSA exponent); presumably set only for RSA keys by the converter.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub e: Option<String>,
    /// JWK member `crv` (EC curve name); presumably set only for EC keys.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub crv: Option<String>,
    /// JWK member `x` (EC x coordinate); presumably set only for EC keys.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub x: Option<String>,
    /// JWK member `y` (EC y coordinate); presumably set only for EC keys.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub y: Option<String>,
}
165
/// Runtime state of the OIDC provider, built by [`OidcState::new`].
#[derive(Debug, Clone)]
pub struct OidcState {
    /// The configuration this state was built from.
    pub config: OidcConfig,
    /// Signing keys indexed by `kid`, loaded from configured keys that have a private key.
    /// Held in an `Arc`, so clones of the state share the same map.
    pub signing_keys: Arc<RwLock<HashMap<String, EncodingKey>>>,
}
174
175impl OidcState {
176 pub fn new(config: OidcConfig) -> Result<Self, Error> {
178 let mut signing_keys = HashMap::new();
179
180 for key in &config.jwks.keys {
182 if let Some(ref private_key) = key.private_key {
183 let encoding_key = match key.alg.as_str() {
184 "RS256" | "RS384" | "RS512" => {
185 EncodingKey::from_rsa_pem(private_key.as_bytes())
186 .map_err(|e| Error::generic(format!("Failed to load RSA key {}: {}", key.kid, e)))?
187 }
188 "ES256" | "ES384" | "ES512" => {
189 EncodingKey::from_ec_pem(private_key.as_bytes())
190 .map_err(|e| Error::generic(format!("Failed to load EC key {}: {}", key.kid, e)))?
191 }
192 "HS256" | "HS384" | "HS512" => {
193 EncodingKey::from_secret(private_key.as_bytes())
194 }
195 _ => {
196 return Err(Error::generic(format!("Unsupported algorithm: {}", key.alg)));
197 }
198 };
199 signing_keys.insert(key.kid.clone(), encoding_key);
200 }
201 }
202
203 Ok(Self {
204 config,
205 signing_keys: Arc::new(RwLock::new(signing_keys)),
206 })
207 }
208}
209
210pub async fn get_oidc_discovery() -> Json<OidcDiscoveryDocument> {
212 let base_url = std::env::var("MOCKFORGE_BASE_URL")
215 .unwrap_or_else(|_| "https://mockforge.example.com".to_string());
216
217 let discovery = OidcDiscoveryDocument {
218 issuer: base_url.clone(),
219 authorization_endpoint: format!("{}/oauth2/authorize", base_url),
220 token_endpoint: format!("{}/oauth2/token", base_url),
221 userinfo_endpoint: format!("{}/oauth2/userinfo", base_url),
222 jwks_uri: format!("{}/.well-known/jwks.json", base_url),
223 response_types_supported: vec![
224 "code".to_string(),
225 "id_token".to_string(),
226 "token id_token".to_string(),
227 ],
228 subject_types_supported: vec!["public".to_string()],
229 id_token_signing_alg_values_supported: vec![
230 "RS256".to_string(),
231 "ES256".to_string(),
232 "HS256".to_string(),
233 ],
234 scopes_supported: vec![
235 "openid".to_string(),
236 "profile".to_string(),
237 "email".to_string(),
238 "address".to_string(),
239 "phone".to_string(),
240 ],
241 claims_supported: vec![
242 "sub".to_string(),
243 "iss".to_string(),
244 "aud".to_string(),
245 "exp".to_string(),
246 "iat".to_string(),
247 "auth_time".to_string(),
248 "nonce".to_string(),
249 "email".to_string(),
250 "email_verified".to_string(),
251 "name".to_string(),
252 "given_name".to_string(),
253 "family_name".to_string(),
254 ],
255 grant_types_supported: vec![
256 "authorization_code".to_string(),
257 "implicit".to_string(),
258 "refresh_token".to_string(),
259 "client_credentials".to_string(),
260 ],
261 };
262
263 Json(discovery)
264}
265
266pub async fn get_jwks() -> Json<JwksResponse> {
268 let jwks = JwksResponse {
271 keys: vec![],
272 };
273
274 Json(jwks)
275}
276
277pub fn get_jwks_from_state(oidc_state: &OidcState) -> Result<JwksResponse, Error> {
279 use crate::auth::jwks_converter::convert_jwk_key_simple;
280
281 let mut public_keys = Vec::new();
282
283 for key in &oidc_state.config.jwks.keys {
284 match convert_jwk_key_simple(key) {
285 Ok(jwk) => public_keys.push(jwk),
286 Err(e) => {
287 tracing::warn!("Failed to convert key {} to JWK format: {}", key.kid, e);
288 }
290 }
291 }
292
293 Ok(JwksResponse { keys: public_keys })
294}
295
296pub fn generate_signed_jwt(
307 mut claims: HashMap<String, serde_json::Value>,
308 kid: Option<String>,
309 algorithm: Algorithm,
310 encoding_key: &EncodingKey,
311 expires_in_seconds: Option<i64>,
312 issuer: Option<String>,
313 audience: Option<String>,
314) -> Result<String, Error> {
315 use chrono::Utc;
316
317 let mut header = Header::new(algorithm);
318 if let Some(kid) = kid {
319 header.kid = Some(kid);
320 }
321
322 let now = Utc::now();
324 claims.insert("iat".to_string(), json!(now.timestamp()));
325
326 if let Some(exp_seconds) = expires_in_seconds {
327 let exp = now + chrono::Duration::seconds(exp_seconds);
328 claims.insert("exp".to_string(), json!(exp.timestamp()));
329 }
330
331 if let Some(iss) = issuer {
332 claims.insert("iss".to_string(), json!(iss));
333 }
334
335 if let Some(aud) = audience {
336 claims.insert("aud".to_string(), json!(aud));
337 }
338
339 let token = jsonwebtoken::encode(&header, &claims, encoding_key)
340 .map_err(|e| Error::generic(format!("Failed to sign JWT: {}", e)))?;
341
342 Ok(token)
343}
344
/// Caller-supplied tenant identifiers, used by `generate_oidc_token` when multi-tenant
/// claims are enabled.
#[derive(Debug, Clone)]
pub struct TenantContext {
    /// Organization id; `generate_oidc_token` substitutes `"org-default"` when `None`.
    pub org_id: Option<String>,
    /// Tenant id; `generate_oidc_token` substitutes `"tenant-default"` when `None`.
    pub tenant_id: Option<String>,
}
353
354pub fn generate_oidc_token(
356 oidc_state: &OidcState,
357 subject: String,
358 additional_claims: Option<HashMap<String, serde_json::Value>>,
359 expires_in_seconds: Option<i64>,
360 tenant_context: Option<TenantContext>,
361) -> Result<String, Error> {
362 use chrono::Utc;
363 use jsonwebtoken::Algorithm;
364
365 let mut claims = HashMap::new();
367 claims.insert("sub".to_string(), json!(subject));
368 claims.insert("iss".to_string(), json!(oidc_state.config.issuer.clone()));
369
370 for claim_name in &oidc_state.config.claims.default {
372 if !claims.contains_key(claim_name) {
373 match claim_name.as_str() {
375 "sub" | "iss" => {} "exp" => {
377 let exp_seconds = expires_in_seconds.unwrap_or(3600);
378 let exp = Utc::now() + chrono::Duration::seconds(exp_seconds);
379 claims.insert("exp".to_string(), json!(exp.timestamp()));
380 }
381 "iat" => {
382 claims.insert("iat".to_string(), json!(Utc::now().timestamp()));
383 }
384 _ => {
385 if let Some(value) = oidc_state.config.claims.custom.get(claim_name) {
387 claims.insert(claim_name.clone(), value.clone());
388 }
389 }
390 }
391 }
392 }
393
394 for (key, value) in &oidc_state.config.claims.custom {
396 if !claims.contains_key(key) {
397 claims.insert(key.clone(), value.clone());
398 }
399 }
400
401 if let Some(ref mt_config) = oidc_state.config.multi_tenant {
403 if mt_config.enabled {
404 let org_id = tenant_context
406 .as_ref()
407 .and_then(|ctx| ctx.org_id.clone())
408 .unwrap_or_else(|| "org-default".to_string());
409 let tenant_id = tenant_context
410 .as_ref()
411 .and_then(|ctx| ctx.tenant_id.clone())
412 .or_else(|| Some("tenant-default".to_string()));
413
414 claims.insert(mt_config.org_id_claim.clone(), json!(org_id));
415 if let Some(ref tenant_claim) = mt_config.tenant_id_claim {
416 if let Some(tid) = tenant_id {
417 claims.insert(tenant_claim.clone(), json!(tid));
418 }
419 }
420 }
421 }
422
423 if let Some(additional) = additional_claims {
425 for (key, value) in additional {
426 claims.insert(key, value);
427 }
428 }
429
430 let signing_keys = oidc_state.signing_keys.blocking_read();
432 let (kid, encoding_key) = signing_keys
433 .iter()
434 .next()
435 .ok_or_else(|| Error::generic("No signing keys available".to_string()))?;
436
437 let algorithm = oidc_state
440 .config
441 .jwks
442 .keys
443 .iter()
444 .find(|k| k.kid == *kid)
445 .and_then(|k| match k.alg.as_str() {
446 "RS256" => Some(Algorithm::RS256),
447 "RS384" => Some(Algorithm::RS384),
448 "RS512" => Some(Algorithm::RS512),
449 "ES256" => Some(Algorithm::ES256),
450 "ES384" => Some(Algorithm::ES384),
451 "ES512" => Some(Algorithm::ES512),
452 "HS256" => Some(Algorithm::HS256),
453 "HS384" => Some(Algorithm::HS384),
454 "HS512" => Some(Algorithm::HS512),
455 _ => None,
456 })
457 .unwrap_or(Algorithm::HS256);
458
459 generate_signed_jwt(
460 claims,
461 Some(kid.clone()),
462 algorithm,
463 encoding_key,
464 expires_in_seconds,
465 Some(oidc_state.config.issuer.clone()),
466 None,
467 )
468}
469
470pub fn oidc_router() -> axum::Router {
472 use axum::routing::get;
473
474 axum::Router::new()
475 .route("/.well-known/openid-configuration", get(get_oidc_discovery))
476 .route("/.well-known/jwks.json", get(get_jwks))
477}
478