1use std::collections::{HashMap, HashSet};
2use std::fmt;
3use std::sync::Arc;
4
5use axum::Router;
6use minijinja::Environment;
7
8use allowthem_core::{AllowThem, LifecycleEventSender, OAuthProvider};
9
10use crate::browser_templates::build_default_browser_env;
11
/// A selectable group of routes for [`AllRoutesBuilder`].
///
/// Each variant gates one or more routers that the builder merges into the
/// final `Router` when the group is selected, either individually or via
/// [`AllRoutesBuilder::all_routes`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum RouteGroup {
    /// Login pages (`crate::login_routes`), mounted behind the CSRF layer.
    Login,
    /// Registration pages (`crate::register_routes`), mounted behind the CSRF layer.
    Register,
    /// Logout route (`crate::logout_routes`), mounted behind the CSRF layer.
    Logout,
    /// Settings pages (`crate::settings_routes`), mounted behind the CSRF layer.
    Settings,
    /// Consent pages (`crate::consent_routes`), mounted behind the CSRF layer.
    Consent,
    /// Password reset pages (behind the CSRF layer) plus
    /// `crate::password_reset_routes` (outside it). Requires `base_url`.
    PasswordReset,
    /// MFA setup pages (behind the CSRF layer) plus the MFA challenge pages and
    /// `crate::mfa_routes` (outside it). Requires `base_url` and `mfa_issuer`.
    Mfa,
    /// `crate::oauth_routes`, mounted outside the CSRF layer.
    /// Requires `base_url` and `oauth_providers`.
    OAuth,
    /// Token route (`crate::token_route`); part of the OIDC group that is
    /// wrapped in CORS middleware when `cors()` is enabled.
    Token,
    /// UserInfo route (`crate::userinfo_route`); part of the OIDC group.
    UserInfo,
    /// `crate::well_known_routes`, built with `base_url`; part of the OIDC group.
    WellKnown,
}
27
/// Reasons [`AllRoutesBuilder`] can refuse to build a router.
#[derive(Debug)]
pub enum AllRoutesError {
    /// No route group was selected and `all_routes()` was not called.
    NoRoutesSelected,
    /// A selected route group (or the events channel) needs a setting that
    /// was not provided; the payload describes which one.
    MissingConfig(String),
    /// The custom registration fields schema was rejected; the payload is the
    /// underlying validation or compilation message.
    InvalidSchema(String),
}

impl fmt::Display for AllRoutesError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Variants carrying a detail string share the "<prefix>: <detail>" shape.
        let (prefix, detail) = match self {
            Self::NoRoutesSelected => return f.write_str("no route groups selected"),
            Self::MissingConfig(msg) => ("missing config", msg),
            Self::InvalidSchema(msg) => ("invalid custom fields schema", msg),
        };
        write!(f, "{prefix}: {detail}")
    }
}

impl std::error::Error for AllRoutesError {}
47
/// Builder that assembles the selected allowthem route groups into a single
/// axum `Router`.
///
/// Select groups with the per-group methods (or `all_routes()`), supply the
/// configuration those groups need, then call `build` or `build_for_saas`.
pub struct AllRoutesBuilder {
    // Template environment for browser pages; `build_default_browser_env()` is
    // used when this is left unset.
    templates: Option<Arc<Environment<'static>>>,
    // Passed through to every template-rendering route group.
    is_production: bool,
    // Public base URL. `validate` requires it when PasswordReset, Mfa, OAuth or
    // WellKnown is selected, or when an events channel is configured.
    base_url: Option<String>,

    // Login attempt limit handed to the login routes (default 10).
    max_login_attempts: u32,
    // Window, in seconds, handed to the login routes with the attempt limit
    // (default 900).
    rate_limit_window_secs: u64,
    // Provider names given to the login and register routes. When unset, the
    // sorted keys of `oauth_provider_impls` are used (or an empty list).
    oauth_providers_list: Option<Vec<String>>,
    // Optional overrides handed to `login_routes`.
    login_overrides: Option<crate::login_routes::LoginOverrides>,

    // Provider implementations keyed by name; required when OAuth is selected.
    oauth_provider_impls: Option<HashMap<String, Box<dyn OAuthProvider>>>,

    // Issuer string passed to `mfa_routes`; required when Mfa is selected.
    mfa_issuer: Option<String>,

    // JSON schema for extra registration fields. It is validated and compiled
    // only when the Register group is selected.
    custom_fields_schema: Option<serde_json::Value>,
    // Passed to `register_routes` (default `true`); presumably toggles open
    // sign-up — NOTE(review): confirm in register_routes.
    public_registration: bool,

    // Lifecycle event sender cloned into the register and OAuth routes.
    events_tx: Option<LifecycleEventSender>,

    // Groups selected individually.
    routes: HashSet<RouteGroup>,
    // When true, every group counts as selected regardless of `routes`.
    all: bool,

    // When true, the OIDC routes (token, userinfo, well-known) are wrapped in
    // `crate::cors::cors_middleware`.
    cors_enabled: bool,

    // Default branding, installed by the build as an `Arc<DefaultBranding>`
    // request extension.
    default_branding: Option<allowthem_core::applications::BrandingConfig>,
}
85
86impl Default for AllRoutesBuilder {
87 fn default() -> Self {
88 Self::new()
89 }
90}
91
92impl AllRoutesBuilder {
93 pub fn new() -> Self {
94 Self {
95 templates: None,
96 is_production: false,
97 base_url: None,
98 max_login_attempts: 10,
99 rate_limit_window_secs: 900,
100 oauth_providers_list: None,
101 login_overrides: None,
102 oauth_provider_impls: None,
103 mfa_issuer: None,
104 custom_fields_schema: None,
105 public_registration: true,
106 events_tx: None,
107 routes: HashSet::new(),
108 all: false,
109 cors_enabled: false,
110 default_branding: None,
111 }
112 }
113
114 pub fn templates(mut self, templates: Arc<Environment<'static>>) -> Self {
117 self.templates = Some(templates);
118 self
119 }
120
121 pub fn is_production(mut self, is_production: bool) -> Self {
122 self.is_production = is_production;
123 self
124 }
125
126 pub fn base_url(mut self, base_url: impl Into<String>) -> Self {
127 self.base_url = Some(base_url.into());
128 self
129 }
130
131 pub fn events(mut self, tx: LifecycleEventSender) -> Self {
136 self.events_tx = Some(tx);
137 self
138 }
139
140 pub fn default_branding(
145 mut self,
146 branding: allowthem_core::applications::BrandingConfig,
147 ) -> Self {
148 self.default_branding = Some(branding);
149 self
150 }
151
152 pub fn max_login_attempts(mut self, max: u32) -> Self {
155 self.max_login_attempts = max;
156 self
157 }
158
159 pub fn rate_limit_window_secs(mut self, secs: u64) -> Self {
160 self.rate_limit_window_secs = secs;
161 self
162 }
163
164 pub fn oauth_providers_list(mut self, providers: Vec<String>) -> Self {
165 self.oauth_providers_list = Some(providers);
166 self
167 }
168
169 pub fn login_overrides(mut self, overrides: crate::login_routes::LoginOverrides) -> Self {
170 self.login_overrides = Some(overrides);
171 self
172 }
173
174 pub fn oauth_providers(mut self, providers: HashMap<String, Box<dyn OAuthProvider>>) -> Self {
177 self.oauth_provider_impls = Some(providers);
178 self
179 }
180
181 pub fn mfa_issuer(mut self, issuer: impl Into<String>) -> Self {
184 self.mfa_issuer = Some(issuer.into());
185 self
186 }
187
188 pub fn custom_fields_schema(mut self, schema: serde_json::Value) -> Self {
191 self.custom_fields_schema = Some(schema);
192 self
193 }
194
195 pub fn public_registration(mut self, enabled: bool) -> Self {
196 self.public_registration = enabled;
197 self
198 }
199
200 pub fn login(mut self) -> Self {
203 self.routes.insert(RouteGroup::Login);
204 self
205 }
206
207 pub fn register(mut self) -> Self {
208 self.routes.insert(RouteGroup::Register);
209 self
210 }
211
212 pub fn logout(mut self) -> Self {
213 self.routes.insert(RouteGroup::Logout);
214 self
215 }
216
217 pub fn settings(mut self) -> Self {
218 self.routes.insert(RouteGroup::Settings);
219 self
220 }
221
222 pub fn consent(mut self) -> Self {
223 self.routes.insert(RouteGroup::Consent);
224 self
225 }
226
227 pub fn password_reset(mut self) -> Self {
228 self.routes.insert(RouteGroup::PasswordReset);
229 self
230 }
231
232 pub fn mfa(mut self) -> Self {
233 self.routes.insert(RouteGroup::Mfa);
234 self
235 }
236
237 pub fn oauth(mut self) -> Self {
238 self.routes.insert(RouteGroup::OAuth);
239 self
240 }
241
242 pub fn token(mut self) -> Self {
243 self.routes.insert(RouteGroup::Token);
244 self
245 }
246
247 pub fn userinfo(mut self) -> Self {
248 self.routes.insert(RouteGroup::UserInfo);
249 self
250 }
251
252 pub fn well_known(mut self) -> Self {
253 self.routes.insert(RouteGroup::WellKnown);
254 self
255 }
256
257 pub fn all_routes(mut self) -> Self {
258 self.all = true;
259 self
260 }
261
262 pub fn cors(mut self) -> Self {
267 self.cors_enabled = true;
268 self
269 }
270
271 fn selected(&self, group: RouteGroup) -> bool {
274 self.all || self.routes.contains(&group)
275 }
276
277 fn validate(&self) -> Result<(), AllRoutesError> {
278 if !self.all && self.routes.is_empty() {
279 return Err(AllRoutesError::NoRoutesSelected);
280 }
281
282 let needs_base_url = self.selected(RouteGroup::PasswordReset)
283 || self.selected(RouteGroup::Mfa)
284 || self.selected(RouteGroup::OAuth)
285 || self.selected(RouteGroup::WellKnown);
286
287 if needs_base_url && self.base_url.is_none() {
288 return Err(AllRoutesError::MissingConfig(
289 "base_url required by selected route groups".into(),
290 ));
291 }
292
293 if self.events_tx.is_some() && self.base_url.is_none() {
294 return Err(AllRoutesError::MissingConfig(
295 "base_url required when events channel is configured".into(),
296 ));
297 }
298
299 if self.selected(RouteGroup::OAuth) && self.oauth_provider_impls.is_none() {
300 return Err(AllRoutesError::MissingConfig(
301 "oauth_providers required when oauth routes are selected".into(),
302 ));
303 }
304
305 if self.selected(RouteGroup::Mfa) && self.mfa_issuer.is_none() {
306 return Err(AllRoutesError::MissingConfig(
307 "mfa_issuer required when mfa routes are selected".into(),
308 ));
309 }
310
311 Ok(())
312 }
313
314 fn build_inner(mut self) -> Result<Router<()>, AllRoutesError> {
315 self.validate()?;
316
317 let templates = self
320 .templates
321 .take()
322 .unwrap_or_else(build_default_browser_env);
323 let is_production = self.is_production;
324
325 let oauth_providers_list = self
329 .oauth_providers_list
330 .take()
331 .unwrap_or_else(|| match &self.oauth_provider_impls {
332 Some(map) => {
333 let mut names: Vec<String> = map.keys().cloned().collect();
334 names.sort();
335 names
336 }
337 None => Vec::new(),
338 });
339
340 let mut csrf_protected: Router<()> = Router::new();
343
344 if self.selected(RouteGroup::Login) {
345 csrf_protected = csrf_protected.merge(crate::login_routes::login_routes(
346 templates.clone(),
347 is_production,
348 self.max_login_attempts,
349 self.rate_limit_window_secs,
350 oauth_providers_list.clone(),
351 self.login_overrides.take(),
352 ));
353 }
354
355 if self.selected(RouteGroup::Register) {
356 let custom_schema = if let Some(schema) = self.custom_fields_schema.take() {
357 crate::custom_fields::validate_custom_schema(&schema)
358 .map_err(AllRoutesError::InvalidSchema)?;
359 let validator = jsonschema::validator_for(&schema)
360 .map_err(|e| AllRoutesError::InvalidSchema(e.to_string()))?;
361 let fields = crate::custom_fields::extract_field_descriptors(&schema);
362 Some(crate::custom_fields::CustomSchemaConfig {
363 schema,
364 validator,
365 fields,
366 })
367 } else {
368 None
369 };
370 csrf_protected = csrf_protected.merge(crate::register_routes::register_routes(
371 templates.clone(),
372 is_production,
373 custom_schema,
374 self.events_tx.clone(),
375 self.base_url.clone(),
376 oauth_providers_list.clone(),
377 self.public_registration,
378 ));
379 }
380
381 if self.selected(RouteGroup::Logout) {
382 csrf_protected = csrf_protected.merge(crate::logout_routes::logout_routes());
383 }
384
385 if self.selected(RouteGroup::Settings) {
386 csrf_protected = csrf_protected.merge(crate::settings_routes::settings_routes(
387 templates.clone(),
388 is_production,
389 ));
390 }
391
392 if self.selected(RouteGroup::Consent) {
393 csrf_protected = csrf_protected.merge(crate::consent_routes::consent_routes(
394 templates.clone(),
395 is_production,
396 ));
397 }
398
399 if self.selected(RouteGroup::PasswordReset) {
400 csrf_protected = csrf_protected.merge(
401 crate::password_reset_page_routes::password_reset_page_routes(
402 templates.clone(),
403 is_production,
404 ),
405 );
406 }
407
408 if self.selected(RouteGroup::Mfa) {
409 let base_url = self.base_url.clone().expect("validated above");
410 csrf_protected = csrf_protected.merge(crate::mfa_page_routes::mfa_setup_routes(
411 templates.clone(),
412 is_production,
413 base_url,
414 ));
415 }
416
417 let mut non_csrf: Router<()> = Router::new();
420
421 if self.selected(RouteGroup::Mfa) {
422 non_csrf = non_csrf.merge(crate::mfa_page_routes::mfa_challenge_routes(
423 templates.clone(),
424 is_production,
425 ));
426 let issuer = self.mfa_issuer.take().expect("validated above");
427 non_csrf = non_csrf.merge(crate::mfa_routes::mfa_routes(issuer));
428 }
429
430 if self.selected(RouteGroup::OAuth) {
431 let providers = self.oauth_provider_impls.take().expect("validated above");
432 let base_url = self.base_url.clone().expect("validated above");
433 non_csrf = non_csrf.merge(crate::oauth_routes::oauth_routes(
434 providers,
435 base_url,
436 self.events_tx.clone(),
437 ));
438 }
439
440 let mut oidc: Router<()> = Router::new();
443
444 if self.selected(RouteGroup::Token) {
445 oidc = oidc.merge(crate::token_route::token_route());
446 }
447
448 if self.selected(RouteGroup::UserInfo) {
449 oidc = oidc.merge(crate::userinfo_route::userinfo_route());
450 }
451
452 if self.selected(RouteGroup::WellKnown) {
453 let base_url = self.base_url.clone().expect("validated above");
454 oidc = oidc.merge(crate::well_known_routes::well_known_routes(base_url));
455 }
456
457 let oidc_final: Router<()> = if self.cors_enabled {
460 oidc.layer(axum::middleware::from_fn(crate::cors::cors_middleware))
461 } else {
462 oidc
463 };
464
465 non_csrf = non_csrf.merge(oidc_final);
466
467 if self.selected(RouteGroup::PasswordReset) {
468 non_csrf = non_csrf.merge(crate::password_reset_routes::password_reset_routes());
469 }
470
471 let default_branding = self.default_branding.take();
472
473 let mut csrf_protected =
477 csrf_protected.layer(axum::middleware::from_fn(crate::csrf::csrf_middleware));
478
479 if let Some(branding) = default_branding {
480 let default = Arc::new(crate::branding::DefaultBranding(branding));
481 csrf_protected = csrf_protected.layer(axum::Extension(default));
482 }
483
484 Ok(csrf_protected
485 .merge(non_csrf)
486 .merge(crate::static_routes::router()))
487 }
488
489 pub fn build(self, ath: &AllowThem) -> Result<Router<()>, AllRoutesError> {
492 let inner = self.build_inner()?;
493 Ok(inner.layer(axum::middleware::from_fn_with_state(
494 ath.clone(),
495 crate::cors::inject_ath_into_extensions,
496 )))
497 }
498
499 pub fn build_for_saas(self) -> Result<Router<()>, AllRoutesError> {
502 self.build_inner()
503 }
504}
505
#[cfg(test)]
mod tests {
    use super::*;
    use allowthem_core::{AllowThem, AllowThemBuilder};

    /// In-memory `AllowThem` with no CSRF key configured.
    async fn memory_ath() -> AllowThem {
        AllowThemBuilder::new("sqlite::memory:")
            .build()
            .await
            .unwrap()
    }

    /// In-memory `AllowThem` with a fixed CSRF key. Used by the tests below
    /// that select CSRF-protected groups (register, login, ...).
    async fn memory_ath_with_csrf() -> AllowThem {
        AllowThemBuilder::new("sqlite::memory:")
            .csrf_key(*b"test-csrf-key-for-server-tests!!")
            .build()
            .await
            .unwrap()
    }

    #[tokio::test]
    async fn build_fails_no_routes_selected() {
        let ath = memory_ath().await;
        let result = AllRoutesBuilder::new().build(&ath);
        assert!(matches!(result, Err(AllRoutesError::NoRoutesSelected)));
    }

    #[tokio::test]
    async fn build_fails_oauth_without_providers() {
        let ath = memory_ath().await;
        let result = AllRoutesBuilder::new()
            .base_url("http://localhost")
            .oauth()
            .build(&ath);
        assert!(matches!(result, Err(AllRoutesError::MissingConfig(_))));
    }

    #[tokio::test]
    async fn build_fails_mfa_without_issuer() {
        let ath = memory_ath().await;
        let result = AllRoutesBuilder::new()
            .base_url("http://localhost")
            .mfa()
            .build(&ath);
        assert!(matches!(result, Err(AllRoutesError::MissingConfig(_))));
    }

    #[tokio::test]
    async fn build_fails_events_without_base_url() {
        let ath = memory_ath_with_csrf().await;
        let (tx, _rx) = tokio::sync::mpsc::unbounded_channel();
        let result = AllRoutesBuilder::new().events(tx).register().build(&ath);
        match result {
            Err(AllRoutesError::MissingConfig(msg)) => assert!(msg.contains("base_url")),
            other => panic!("expected MissingConfig, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn build_fails_missing_base_url() {
        let ath = memory_ath().await;
        let result = AllRoutesBuilder::new().well_known().build(&ath);
        assert!(matches!(result, Err(AllRoutesError::MissingConfig(_))));
    }

    #[tokio::test]
    async fn build_with_invalid_schema_returns_error() {
        let ath = memory_ath_with_csrf().await;
        // A top-level array schema is rejected by the custom fields validator.
        let schema = serde_json::json!({"type": "array"});
        let result = AllRoutesBuilder::new()
            .register()
            .custom_fields_schema(schema)
            .build(&ath);
        assert!(matches!(result, Err(AllRoutesError::InvalidSchema(_))));
    }

    #[test]
    fn default_branding_setter_stores_value() {
        use allowthem_core::applications::BrandingConfig;
        let builder = AllRoutesBuilder::new().default_branding(BrandingConfig::new("Fixture Co"));
        assert_eq!(
            builder
                .default_branding
                .as_ref()
                .map(|b| b.application_name.as_str()),
            Some("Fixture Co")
        );
    }

    #[test]
    fn default_branding_absent_by_default() {
        let builder = AllRoutesBuilder::new();
        assert!(builder.default_branding.is_none());
    }

    #[tokio::test]
    async fn build_succeeds_for_simple_routes() {
        let ath = memory_ath_with_csrf().await;
        // None of these groups needs base_url, providers or an MFA issuer.
        let result = AllRoutesBuilder::new()
            .login()
            .register()
            .logout()
            .settings()
            .consent()
            .token()
            .userinfo()
            .build(&ath);
        assert!(result.is_ok());
    }
}