1use std::collections::{HashMap, HashSet};
2use std::fmt;
3use std::sync::Arc;
4
5use axum::Router;
6use minijinja::Environment;
7
8use allowthem_core::{AllowThem, AuthEventSender, EmailSender, OAuthProvider};
9
10use crate::browser_templates::build_default_browser_env;
11
/// A selectable group of routes that `AllRoutesBuilder` can mount.
///
/// Each variant is switched on by the builder method with the matching snake_case name
/// (`login`, `register`, ..., `well_known`), or all at once via `all_routes`.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum RouteGroup {
    /// Login pages (`crate::login_routes`).
    Login,
    /// Registration pages (`crate::register_routes`).
    Register,
    /// Logout endpoint (`crate::logout_routes`).
    Logout,
    /// Account settings pages (`crate::settings_routes`).
    Settings,
    /// Consent pages (`crate::consent_routes`).
    Consent,
    /// Password-reset pages and API (`crate::password_reset_page_routes`,
    /// `crate::password_reset_routes`).
    PasswordReset,
    /// MFA setup/challenge pages and MFA API (`crate::mfa_page_routes`, `crate::mfa_routes`).
    Mfa,
    /// OAuth provider routes (`crate::oauth_routes`).
    OAuth,
    /// Token endpoint (`crate::token_route`).
    Token,
    /// UserInfo endpoint (`crate::userinfo_route`).
    UserInfo,
    /// Well-known discovery routes (`crate::well_known_routes`).
    WellKnown,
}
27
/// Reasons `AllRoutesBuilder` can refuse to produce a router.
#[derive(Debug)]
pub enum AllRoutesError {
    /// Neither an individual route group nor `all_routes` was selected.
    NoRoutesSelected,
    /// A required setting is absent; the payload describes which one and why.
    MissingConfig(String),
    /// The custom fields schema was rejected; the payload carries the reason.
    InvalidSchema(String),
}

impl fmt::Display for AllRoutesError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Variants with a payload share the "<prefix><detail>" shape.
        let (prefix, detail) = match self {
            AllRoutesError::NoRoutesSelected => return f.write_str("no route groups selected"),
            AllRoutesError::MissingConfig(detail) => ("missing config: ", detail),
            AllRoutesError::InvalidSchema(detail) => ("invalid custom fields schema: ", detail),
        };
        f.write_str(prefix)?;
        f.write_str(detail)
    }
}

impl std::error::Error for AllRoutesError {}
47
48pub struct AllRoutesBuilder {
51 templates: Option<Arc<Environment<'static>>>,
53 is_production: bool,
54 base_url: Option<String>,
55 email_sender: Option<Arc<dyn EmailSender>>,
56
57 max_login_attempts: u32,
59 rate_limit_window_secs: u64,
60 oauth_providers_list: Option<Vec<String>>,
61
62 oauth_provider_impls: Option<HashMap<String, Box<dyn OAuthProvider>>>,
64
65 mfa_issuer: Option<String>,
67
68 custom_fields_schema: Option<serde_json::Value>,
70
71 events_tx: Option<AuthEventSender>,
73
74 routes: HashSet<RouteGroup>,
76 all: bool,
77
78 cors_enabled: bool,
80
81 default_branding: Option<allowthem_core::applications::BrandingConfig>,
83}
84
85impl Default for AllRoutesBuilder {
86 fn default() -> Self {
87 Self::new()
88 }
89}
90
91impl AllRoutesBuilder {
92 pub fn new() -> Self {
93 Self {
94 templates: None,
95 is_production: false,
96 base_url: None,
97 email_sender: None,
98 max_login_attempts: 10,
99 rate_limit_window_secs: 900,
100 oauth_providers_list: None,
101 oauth_provider_impls: None,
102 mfa_issuer: None,
103 custom_fields_schema: None,
104 events_tx: None,
105 routes: HashSet::new(),
106 all: false,
107 cors_enabled: false,
108 default_branding: None,
109 }
110 }
111
112 pub fn templates(mut self, templates: Arc<Environment<'static>>) -> Self {
115 self.templates = Some(templates);
116 self
117 }
118
119 pub fn is_production(mut self, is_production: bool) -> Self {
120 self.is_production = is_production;
121 self
122 }
123
124 pub fn base_url(mut self, base_url: impl Into<String>) -> Self {
125 self.base_url = Some(base_url.into());
126 self
127 }
128
129 pub fn email_sender(mut self, sender: Arc<dyn EmailSender>) -> Self {
130 self.email_sender = Some(sender);
131 self
132 }
133
134 pub fn events(mut self, tx: AuthEventSender) -> Self {
139 self.events_tx = Some(tx);
140 self
141 }
142
143 pub fn default_branding(
148 mut self,
149 branding: allowthem_core::applications::BrandingConfig,
150 ) -> Self {
151 self.default_branding = Some(branding);
152 self
153 }
154
155 pub fn max_login_attempts(mut self, max: u32) -> Self {
158 self.max_login_attempts = max;
159 self
160 }
161
162 pub fn rate_limit_window_secs(mut self, secs: u64) -> Self {
163 self.rate_limit_window_secs = secs;
164 self
165 }
166
167 pub fn oauth_providers_list(mut self, providers: Vec<String>) -> Self {
168 self.oauth_providers_list = Some(providers);
169 self
170 }
171
172 pub fn oauth_providers(mut self, providers: HashMap<String, Box<dyn OAuthProvider>>) -> Self {
175 self.oauth_provider_impls = Some(providers);
176 self
177 }
178
179 pub fn mfa_issuer(mut self, issuer: impl Into<String>) -> Self {
182 self.mfa_issuer = Some(issuer.into());
183 self
184 }
185
186 pub fn custom_fields_schema(mut self, schema: serde_json::Value) -> Self {
189 self.custom_fields_schema = Some(schema);
190 self
191 }
192
193 pub fn login(mut self) -> Self {
196 self.routes.insert(RouteGroup::Login);
197 self
198 }
199
200 pub fn register(mut self) -> Self {
201 self.routes.insert(RouteGroup::Register);
202 self
203 }
204
205 pub fn logout(mut self) -> Self {
206 self.routes.insert(RouteGroup::Logout);
207 self
208 }
209
210 pub fn settings(mut self) -> Self {
211 self.routes.insert(RouteGroup::Settings);
212 self
213 }
214
215 pub fn consent(mut self) -> Self {
216 self.routes.insert(RouteGroup::Consent);
217 self
218 }
219
220 pub fn password_reset(mut self) -> Self {
221 self.routes.insert(RouteGroup::PasswordReset);
222 self
223 }
224
225 pub fn mfa(mut self) -> Self {
226 self.routes.insert(RouteGroup::Mfa);
227 self
228 }
229
230 pub fn oauth(mut self) -> Self {
231 self.routes.insert(RouteGroup::OAuth);
232 self
233 }
234
235 pub fn token(mut self) -> Self {
236 self.routes.insert(RouteGroup::Token);
237 self
238 }
239
240 pub fn userinfo(mut self) -> Self {
241 self.routes.insert(RouteGroup::UserInfo);
242 self
243 }
244
245 pub fn well_known(mut self) -> Self {
246 self.routes.insert(RouteGroup::WellKnown);
247 self
248 }
249
250 pub fn all_routes(mut self) -> Self {
251 self.all = true;
252 self
253 }
254
255 pub fn cors(mut self) -> Self {
260 self.cors_enabled = true;
261 self
262 }
263
264 fn selected(&self, group: RouteGroup) -> bool {
267 self.all || self.routes.contains(&group)
268 }
269
270 fn validate(&self) -> Result<(), AllRoutesError> {
271 if !self.all && self.routes.is_empty() {
272 return Err(AllRoutesError::NoRoutesSelected);
273 }
274
275 let needs_base_url = self.selected(RouteGroup::PasswordReset)
276 || self.selected(RouteGroup::Mfa)
277 || self.selected(RouteGroup::OAuth)
278 || self.selected(RouteGroup::WellKnown);
279
280 if needs_base_url && self.base_url.is_none() {
281 return Err(AllRoutesError::MissingConfig(
282 "base_url required by selected route groups".into(),
283 ));
284 }
285
286 if self.events_tx.is_some() && self.base_url.is_none() {
287 return Err(AllRoutesError::MissingConfig(
288 "base_url required when events channel is configured".into(),
289 ));
290 }
291
292 if self.selected(RouteGroup::OAuth) && self.oauth_provider_impls.is_none() {
293 return Err(AllRoutesError::MissingConfig(
294 "oauth_providers required when oauth routes are selected".into(),
295 ));
296 }
297
298 if self.selected(RouteGroup::PasswordReset) && self.email_sender.is_none() {
299 return Err(AllRoutesError::MissingConfig(
300 "email_sender required when password_reset routes are selected".into(),
301 ));
302 }
303
304 if self.selected(RouteGroup::Mfa) && self.mfa_issuer.is_none() {
305 return Err(AllRoutesError::MissingConfig(
306 "mfa_issuer required when mfa routes are selected".into(),
307 ));
308 }
309
310 Ok(())
311 }
312
313 fn build_inner(mut self) -> Result<Router<()>, AllRoutesError> {
314 self.validate()?;
315
316 let templates = self
319 .templates
320 .take()
321 .unwrap_or_else(build_default_browser_env);
322 let is_production = self.is_production;
323
324 let oauth_providers_list = self
328 .oauth_providers_list
329 .take()
330 .unwrap_or_else(|| match &self.oauth_provider_impls {
331 Some(map) => {
332 let mut names: Vec<String> = map.keys().cloned().collect();
333 names.sort();
334 names
335 }
336 None => Vec::new(),
337 });
338
339 let mut csrf_protected: Router<()> = Router::new();
342
343 if self.selected(RouteGroup::Login) {
344 csrf_protected = csrf_protected.merge(crate::login_routes::login_routes(
345 templates.clone(),
346 is_production,
347 self.max_login_attempts,
348 self.rate_limit_window_secs,
349 oauth_providers_list.clone(),
350 ));
351 }
352
353 if self.selected(RouteGroup::Register) {
354 let custom_schema = if let Some(schema) = self.custom_fields_schema.take() {
355 crate::custom_fields::validate_custom_schema(&schema)
356 .map_err(AllRoutesError::InvalidSchema)?;
357 let validator = jsonschema::validator_for(&schema)
358 .map_err(|e| AllRoutesError::InvalidSchema(e.to_string()))?;
359 let fields = crate::custom_fields::extract_field_descriptors(&schema);
360 Some(crate::custom_fields::CustomSchemaConfig {
361 schema,
362 validator,
363 fields,
364 })
365 } else {
366 None
367 };
368 csrf_protected = csrf_protected.merge(crate::register_routes::register_routes(
369 templates.clone(),
370 is_production,
371 custom_schema,
372 self.events_tx.clone(),
373 self.base_url.clone(),
374 oauth_providers_list.clone(),
375 ));
376 }
377
378 if self.selected(RouteGroup::Logout) {
379 csrf_protected = csrf_protected.merge(crate::logout_routes::logout_routes());
380 }
381
382 if self.selected(RouteGroup::Settings) {
383 csrf_protected = csrf_protected.merge(crate::settings_routes::settings_routes(
384 templates.clone(),
385 is_production,
386 ));
387 }
388
389 if self.selected(RouteGroup::Consent) {
390 csrf_protected = csrf_protected.merge(crate::consent_routes::consent_routes(
391 templates.clone(),
392 is_production,
393 ));
394 }
395
396 if self.selected(RouteGroup::PasswordReset) {
397 let email_sender = self.email_sender.clone().expect("validated above");
398 let base_url = self.base_url.clone().expect("validated above");
399 csrf_protected = csrf_protected.merge(
400 crate::password_reset_page_routes::password_reset_page_routes(
401 templates.clone(),
402 is_production,
403 email_sender,
404 base_url,
405 ),
406 );
407 }
408
409 if self.selected(RouteGroup::Mfa) {
410 let base_url = self.base_url.clone().expect("validated above");
411 csrf_protected = csrf_protected.merge(crate::mfa_page_routes::mfa_setup_routes(
412 templates.clone(),
413 is_production,
414 base_url,
415 ));
416 }
417
418 let mut non_csrf: Router<()> = Router::new();
421
422 if self.selected(RouteGroup::Mfa) {
423 non_csrf = non_csrf.merge(crate::mfa_page_routes::mfa_challenge_routes(
424 templates.clone(),
425 is_production,
426 ));
427 let issuer = self.mfa_issuer.take().expect("validated above");
428 non_csrf = non_csrf.merge(crate::mfa_routes::mfa_routes(issuer));
429 }
430
431 if self.selected(RouteGroup::OAuth) {
432 let providers = self.oauth_provider_impls.take().expect("validated above");
433 let base_url = self.base_url.clone().expect("validated above");
434 non_csrf = non_csrf.merge(crate::oauth_routes::oauth_routes(
435 providers,
436 base_url,
437 self.events_tx.clone(),
438 ));
439 }
440
441 let mut oidc: Router<()> = Router::new();
444
445 if self.selected(RouteGroup::Token) {
446 oidc = oidc.merge(crate::token_route::token_route());
447 }
448
449 if self.selected(RouteGroup::UserInfo) {
450 oidc = oidc.merge(crate::userinfo_route::userinfo_route());
451 }
452
453 if self.selected(RouteGroup::WellKnown) {
454 let base_url = self.base_url.clone().expect("validated above");
455 oidc = oidc.merge(crate::well_known_routes::well_known_routes(base_url));
456 }
457
458 let oidc_final: Router<()> = if self.cors_enabled {
461 oidc.layer(axum::middleware::from_fn(crate::cors::cors_middleware))
462 } else {
463 oidc
464 };
465
466 non_csrf = non_csrf.merge(oidc_final);
467
468 if self.selected(RouteGroup::PasswordReset) {
469 let email_sender = self.email_sender.take().expect("validated above");
470 let base_url = self.base_url.expect("validated above");
471 non_csrf = non_csrf.merge(crate::password_reset_routes::password_reset_routes(
472 email_sender,
473 base_url,
474 ));
475 }
476
477 let default_branding = self.default_branding.take();
478
479 let mut csrf_protected =
483 csrf_protected.layer(axum::middleware::from_fn(crate::csrf::csrf_middleware));
484
485 if let Some(branding) = default_branding {
486 let default = Arc::new(crate::branding::DefaultBranding(branding));
487 csrf_protected = csrf_protected.layer(axum::Extension(default));
488 }
489
490 Ok(csrf_protected
491 .merge(non_csrf)
492 .merge(crate::static_routes::router()))
493 }
494
495 pub fn build(self, ath: &AllowThem) -> Result<Router<()>, AllRoutesError> {
498 let inner = self.build_inner()?;
499 Ok(inner.layer(axum::middleware::from_fn_with_state(
500 ath.clone(),
501 crate::cors::inject_ath_into_extensions,
502 )))
503 }
504
505 pub fn build_for_saas(self) -> Result<Router<()>, AllRoutesError> {
508 self.build_inner()
509 }
510}
511
#[cfg(test)]
mod tests {
    use super::*;
    use allowthem_core::AllowThemBuilder;

    /// Builds `builder` against a fresh in-memory `AllowThem` with a fixed CSRF key.
    async fn build_routes(builder: AllRoutesBuilder) -> Result<Router<()>, AllRoutesError> {
        let ath = AllowThemBuilder::new("sqlite::memory:")
            .csrf_key(*b"test-csrf-key-for-server-tests!!")
            .build()
            .await
            .unwrap();
        builder.build(&ath)
    }

    /// Asserts `result` is `MissingConfig` and that its message mentions `needle`.
    /// Without the message check, a different requirement failing would also pass.
    fn assert_missing_config(result: Result<Router<()>, AllRoutesError>, needle: &str) {
        match result {
            Err(AllRoutesError::MissingConfig(msg)) => assert!(
                msg.contains(needle),
                "expected MissingConfig mentioning {needle:?}, got {msg:?}"
            ),
            Err(other) => panic!("expected MissingConfig, got {other:?}"),
            Ok(_) => panic!("expected MissingConfig, got Ok"),
        }
    }

    #[tokio::test]
    async fn build_fails_no_routes_selected() {
        let result = build_routes(AllRoutesBuilder::new()).await;
        assert!(matches!(result, Err(AllRoutesError::NoRoutesSelected)));
    }

    #[tokio::test]
    async fn build_fails_oauth_without_providers() {
        let result = build_routes(
            AllRoutesBuilder::new()
                .base_url("http://localhost")
                .oauth(),
        )
        .await;
        assert_missing_config(result, "oauth_providers");
    }

    #[tokio::test]
    async fn build_fails_password_reset_without_email_sender() {
        let result = build_routes(
            AllRoutesBuilder::new()
                .base_url("http://localhost")
                .password_reset(),
        )
        .await;
        assert_missing_config(result, "email_sender");
    }

    #[tokio::test]
    async fn build_fails_mfa_without_issuer() {
        let result = build_routes(
            AllRoutesBuilder::new()
                .base_url("http://localhost")
                .mfa(),
        )
        .await;
        assert_missing_config(result, "mfa_issuer");
    }

    #[tokio::test]
    async fn build_fails_events_without_base_url() {
        let (tx, _rx) = tokio::sync::mpsc::unbounded_channel();
        let result = build_routes(AllRoutesBuilder::new().events(tx).register()).await;
        assert_missing_config(result, "events channel");
    }

    #[tokio::test]
    async fn build_fails_missing_base_url() {
        let result = build_routes(AllRoutesBuilder::new().well_known()).await;
        assert_missing_config(result, "base_url");
    }

    #[tokio::test]
    async fn all_routes_requires_base_url() {
        let result = build_routes(AllRoutesBuilder::new().all_routes()).await;
        assert_missing_config(result, "base_url required by selected route groups");
    }

    #[tokio::test]
    async fn all_routes_requires_oauth_providers_once_base_url_set() {
        let result = build_routes(
            AllRoutesBuilder::new()
                .base_url("http://localhost")
                .all_routes(),
        )
        .await;
        assert_missing_config(result, "oauth_providers");
    }

    #[tokio::test]
    async fn build_with_invalid_schema_returns_error() {
        let schema = serde_json::json!({"type": "array"});
        let result = build_routes(
            AllRoutesBuilder::new()
                .register()
                .custom_fields_schema(schema),
        )
        .await;
        assert!(matches!(result, Err(AllRoutesError::InvalidSchema(_))));
    }

    #[test]
    fn default_branding_setter_stores_value() {
        use allowthem_core::applications::BrandingConfig;
        let builder = AllRoutesBuilder::new().default_branding(BrandingConfig::new("Fixture Co"));
        assert_eq!(
            builder
                .default_branding
                .as_ref()
                .map(|b| b.application_name.as_str()),
            Some("Fixture Co")
        );
    }

    #[test]
    fn default_branding_absent_by_default() {
        let builder = AllRoutesBuilder::new();
        assert!(builder.default_branding.is_none());
    }

    #[tokio::test]
    async fn build_succeeds_for_simple_routes() {
        let result = build_routes(
            AllRoutesBuilder::new()
                .login()
                .register()
                .logout()
                .settings()
                .consent()
                .token()
                .userinfo(),
        )
        .await;
        assert!(result.is_ok());
    }
}