1use crate::{
2 auth_mode::{AuthModeConfig, PluginRegistry},
3 config_error::ConfigError,
4 dispatcher::AuthDispatcher,
5 plugins::{GenericOidcPlugin, KeycloakClaimsPlugin},
6 providers::JwksKeyProvider,
7 validation::ValidationConfig,
8};
9use serde::{Deserialize, Serialize};
10use std::collections::HashMap;
11use std::sync::Arc;
12use std::time::Duration;
13
14#[derive(Debug, Clone, Serialize, Deserialize)]
16pub struct AuthConfig {
17 #[serde(flatten)]
19 pub mode: AuthModeConfig,
20
21 #[serde(default = "default_leeway")]
23 pub leeway_seconds: i64,
24
25 #[serde(default)]
27 pub issuers: Vec<String>,
28
29 #[serde(default)]
31 pub audiences: Vec<String>,
32
33 #[serde(default)]
35 pub jwks: Option<JwksConfig>,
36
37 #[serde(default)]
39 pub plugins: HashMap<String, PluginConfig>,
40}
41
/// Serde default for `AuthConfig::leeway_seconds`: one minute, in seconds.
fn default_leeway() -> i64 {
    const ONE_MINUTE_SECS: i64 = 60;
    ONE_MINUTE_SECS
}
45
46impl Default for AuthConfig {
47 fn default() -> Self {
48 Self {
49 mode: AuthModeConfig::default(),
50 leeway_seconds: 60,
51 issuers: Vec::new(),
52 audiences: Vec::new(),
53 jwks: None,
54 plugins: HashMap::default(),
55 }
56 }
57}
58
59impl AuthConfig {
60 pub fn validate(&self) -> Result<(), ConfigError> {
65 if !self.plugins.contains_key(&self.mode.provider) {
66 return Err(ConfigError::UnknownPlugin(self.mode.provider.clone()));
67 }
68 Ok(())
69 }
70}
71
72#[derive(Debug, Clone, Serialize, Deserialize)]
74pub struct JwksConfig {
75 pub uri: String,
77
78 #[serde(default = "default_refresh_interval")]
80 pub refresh_interval_seconds: u64,
81
82 #[serde(default = "default_max_backoff")]
84 pub max_backoff_seconds: u64,
85}
86
/// Serde default for `JwksConfig::refresh_interval_seconds`: five minutes.
fn default_refresh_interval() -> u64 {
    5 * 60
}
90
/// Serde default for `JwksConfig::max_backoff_seconds`: one hour.
fn default_max_backoff() -> u64 {
    60 * 60
}
94
95#[derive(Debug, Clone, Serialize, Deserialize)]
97#[serde(tag = "type", rename_all = "lowercase")]
98pub enum PluginConfig {
99 Keycloak {
100 #[serde(default = "default_tenant_claim")]
102 tenant_claim: String,
103
104 client_roles: Option<String>,
106
107 role_prefix: Option<String>,
109 },
110 Oidc {
111 #[serde(default = "default_tenant_claim")]
113 tenant_claim: String,
114
115 #[serde(default = "default_roles_claim")]
117 roles_claim: String,
118 },
119}
120
/// Serde default for the `tenant_claim` field of every [`PluginConfig`] variant.
fn default_tenant_claim() -> String {
    String::from("tenants")
}
124
/// Serde default for `PluginConfig::Oidc::roles_claim`.
fn default_roles_claim() -> String {
    String::from("roles")
}
128
129pub fn build_auth_dispatcher(config: &AuthConfig) -> Result<AuthDispatcher, ConfigError> {
134 config.validate()?;
135
136 let validation_config = ValidationConfig {
137 allowed_issuers: config.issuers.clone(),
138 allowed_audiences: config.audiences.clone(),
139 leeway_seconds: config.leeway_seconds,
140 require_uuid_subject: true,
141 require_uuid_tenants: true,
142 };
143
144 let registry = config
145 .plugins
146 .iter()
147 .map(|(name, plugin_config)| {
148 let plugin: Arc<dyn crate::plugin_traits::ClaimsPlugin> = match plugin_config {
149 PluginConfig::Keycloak {
150 tenant_claim,
151 client_roles,
152 role_prefix,
153 } => Arc::new(KeycloakClaimsPlugin::new(
154 tenant_claim,
155 client_roles.clone(),
156 role_prefix.clone(),
157 )),
158 PluginConfig::Oidc {
159 tenant_claim,
160 roles_claim,
161 } => Arc::new(GenericOidcPlugin::new(tenant_claim, roles_claim)),
162 };
163
164 tracing::debug!(
165 plugin_name = %name,
166 plugin_type = ?plugin_config,
167 "Registered claims plugin"
168 );
169
170 (name, plugin)
171 })
172 .fold(PluginRegistry::default(), |mut registry, (name, plugin)| {
173 registry.register(name, plugin);
174 registry
175 });
176
177 let dispatcher = AuthDispatcher::new(validation_config, config, ®istry)?;
178
179 let dispatcher = if let Some(jwks_config) = &config.jwks {
180 let provider = JwksKeyProvider::new(&jwks_config.uri)?
181 .with_refresh_interval(Duration::from_secs(jwks_config.refresh_interval_seconds))
182 .with_max_backoff(Duration::from_secs(jwks_config.max_backoff_seconds));
183
184 dispatcher.with_key_provider(Arc::new(provider))
185 } else {
186 dispatcher
187 };
188
189 tracing::info!(
190 plugin = %config.mode.provider,
191 "Authentication dispatcher initialized (single mode)"
192 );
193
194 Ok(dispatcher)
195}
196
#[cfg(test)]
#[cfg_attr(coverage_nightly, coverage(off))]
mod tests {
    use super::*;

    #[test]
    fn test_default_config() {
        let config = AuthConfig::default();
        assert_eq!(config.leeway_seconds, 60);
        assert!(config.issuers.is_empty());
        assert!(config.audiences.is_empty());
        assert!(config.jwks.is_none());
        assert!(config.plugins.is_empty());
    }

    #[test]
    fn test_default_leeway_matches_serde_default() {
        // `Default` and the serde field default must agree.
        assert_eq!(AuthConfig::default().leeway_seconds, default_leeway());
    }

    #[test]
    fn test_single_mode_config() {
        let mut plugins = HashMap::new();
        plugins.insert(
            "keycloak".to_owned(),
            PluginConfig::Keycloak {
                tenant_claim: "tenants".to_owned(),
                client_roles: Some("modkit-api".to_owned()),
                role_prefix: None,
            },
        );

        let config = AuthConfig {
            mode: AuthModeConfig {
                provider: "keycloak".to_owned(),
            },
            leeway_seconds: 60,
            issuers: vec!["https://auth.example.com".to_owned()],
            audiences: vec!["api".to_owned()],
            jwks: None,
            plugins,
        };

        assert!(config.validate().is_ok());
    }

    #[test]
    fn test_single_mode_unknown_plugin() {
        let config = AuthConfig {
            mode: AuthModeConfig {
                provider: "unknown".to_owned(),
            },
            plugins: HashMap::new(),
            ..Default::default()
        };

        let result = config.validate();
        assert!(matches!(
            result,
            Err(ConfigError::UnknownPlugin(ref name)) if name == "unknown"
        ));
    }

    #[test]
    fn test_config_serialization() {
        let mut plugins = HashMap::new();
        plugins.insert(
            "keycloak".to_owned(),
            PluginConfig::Keycloak {
                tenant_claim: "tenants".to_owned(),
                client_roles: Some("modkit-api".to_owned()),
                role_prefix: Some("kc".to_owned()),
            },
        );

        let config = AuthConfig {
            mode: AuthModeConfig {
                provider: "keycloak".to_owned(),
            },
            leeway_seconds: 120,
            issuers: vec!["https://auth.example.com".to_owned()],
            audiences: vec!["api".to_owned()],
            jwks: Some(JwksConfig {
                uri: "https://auth.example.com/.well-known/jwks.json".to_owned(),
                refresh_interval_seconds: 300,
                max_backoff_seconds: 3600,
            }),
            plugins,
        };

        let json = serde_json::to_string_pretty(&config).unwrap();
        let deserialized: AuthConfig = serde_json::from_str(&json).unwrap();

        assert_eq!(deserialized.mode.provider, "keycloak");
        assert_eq!(deserialized.leeway_seconds, 120);
        assert_eq!(deserialized.issuers, vec!["https://auth.example.com"]);
        assert_eq!(deserialized.audiences, vec!["api"]);
        assert!(matches!(
            deserialized.plugins.get("keycloak"),
            Some(PluginConfig::Keycloak {
                client_roles: Some(roles),
                role_prefix: Some(prefix),
                ..
            }) if roles == "modkit-api" && prefix == "kc"
        ));

        let jwks = deserialized
            .jwks
            .as_ref()
            .expect("jwks should survive a serde round trip");
        assert_eq!(jwks.uri, "https://auth.example.com/.well-known/jwks.json");
        assert_eq!(jwks.refresh_interval_seconds, 300);
        assert_eq!(jwks.max_backoff_seconds, 3600);
    }

    #[test]
    fn test_jwks_config_serde_defaults() {
        let jwks: JwksConfig =
            serde_json::from_str(r#"{"uri":"https://auth.example.com/jwks"}"#).unwrap();
        assert_eq!(jwks.uri, "https://auth.example.com/jwks");
        assert_eq!(jwks.refresh_interval_seconds, 300);
        assert_eq!(jwks.max_backoff_seconds, 3600);
    }

    #[test]
    fn test_plugin_config_serde_defaults() {
        let oidc: PluginConfig = serde_json::from_str(r#"{"type":"oidc"}"#).unwrap();
        assert!(matches!(
            &oidc,
            PluginConfig::Oidc { tenant_claim, roles_claim }
                if tenant_claim == "tenants" && roles_claim == "roles"
        ));

        let keycloak: PluginConfig = serde_json::from_str(r#"{"type":"keycloak"}"#).unwrap();
        assert!(matches!(
            &keycloak,
            PluginConfig::Keycloak {
                tenant_claim,
                client_roles: None,
                role_prefix: None,
            } if tenant_claim == "tenants"
        ));
    }

    #[test]
    fn test_plugin_config_rejects_unknown_type() {
        assert!(serde_json::from_str::<PluginConfig>(r#"{"type":"saml"}"#).is_err());
    }

    #[test]
    fn test_build_dispatcher_unknown_plugin() {
        let config = AuthConfig {
            mode: AuthModeConfig {
                provider: "missing".to_owned(),
            },
            ..Default::default()
        };

        assert!(matches!(
            build_auth_dispatcher(&config),
            Err(ConfigError::UnknownPlugin(ref name)) if name == "missing"
        ));
    }

    #[test]
    fn test_build_dispatcher_with_jwks() {
        let mut plugins = HashMap::new();
        plugins.insert(
            "oidc".to_owned(),
            PluginConfig::Oidc {
                tenant_claim: "tenants".to_owned(),
                roles_claim: "roles".to_owned(),
            },
        );

        let config = AuthConfig {
            mode: AuthModeConfig {
                provider: "oidc".to_owned(),
            },
            leeway_seconds: 60,
            issuers: vec!["https://auth.example.com".to_owned()],
            audiences: vec!["api".to_owned()],
            jwks: Some(JwksConfig {
                uri: "https://auth.example.com/.well-known/jwks.json".to_owned(),
                refresh_interval_seconds: 300,
                max_backoff_seconds: 3600,
            }),
            plugins,
        };

        let dispatcher = build_auth_dispatcher(&config).unwrap();
        assert_eq!(
            dispatcher.validation_config().allowed_issuers,
            vec!["https://auth.example.com"]
        );
        assert_eq!(
            dispatcher.validation_config().allowed_audiences,
            vec!["api"]
        );
    }
}