1use std::collections::{HashMap, HashSet};
2use std::fmt;
3use std::sync::Arc;
4
5use axum::Router;
6
7use allowthem_core::{AllowThem, LifecycleEventSender, OAuthProvider};
8
/// A group of related routes that [`AllRoutesBuilder`] can mount.
///
/// Groups are selected individually through the matching builder method (for
/// example [`AllRoutesBuilder::login`]) or all at once through
/// [`AllRoutesBuilder::all_routes`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum RouteGroup {
    /// Login pages from `crate::login_routes`; mounted behind the CSRF middleware.
    Login,
    /// Registration pages from `crate::register_routes`; mounted behind the CSRF
    /// middleware.
    Register,
    /// Logout route from `crate::logout_routes`; mounted behind the CSRF middleware.
    Logout,
    /// Settings pages from `crate::settings_routes`; mounted behind the CSRF middleware.
    Settings,
    /// Consent pages from `crate::consent_routes`; mounted behind the CSRF middleware.
    Consent,
    /// Password-reset pages (behind the CSRF middleware) plus the routes from
    /// `crate::password_reset_routes` (not behind it). Requires a base URL.
    PasswordReset,
    /// MFA setup pages (behind the CSRF middleware) plus the MFA challenge pages and
    /// `crate::mfa_routes` (not behind it). Requires a base URL and an MFA issuer.
    Mfa,
    /// Routes from `crate::oauth_routes`, not behind the CSRF middleware. Requires a
    /// base URL and OAuth provider implementations.
    OAuth,
    /// Route from `crate::token_route`; wrapped by the CORS middleware when enabled.
    Token,
    /// Route from `crate::userinfo_route`; wrapped by the CORS middleware when enabled.
    UserInfo,
    /// Routes from `crate::well_known_routes`; wrapped by the CORS middleware when
    /// enabled. Requires a base URL.
    WellKnown,
}
24
/// Reasons [`AllRoutesBuilder`] can refuse to build a router.
#[derive(Debug)]
pub enum AllRoutesError {
    /// Neither an individual route group nor `all_routes()` was selected.
    NoRoutesSelected,
    /// A selected route group (or the events channel) lacks a required setting;
    /// the payload describes which one.
    MissingConfig(String),
    /// The custom registration-fields schema failed validation or compilation;
    /// the payload carries the underlying message.
    InvalidSchema(String),
}

impl fmt::Display for AllRoutesError {
    /// Renders a short, lower-case description suitable for logs or startup errors.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match *self {
            AllRoutesError::NoRoutesSelected => write!(f, "no route groups selected"),
            AllRoutesError::MissingConfig(ref msg) => write!(f, "missing config: {msg}"),
            AllRoutesError::InvalidSchema(ref msg) => {
                write!(f, "invalid custom fields schema: {msg}")
            }
        }
    }
}

impl std::error::Error for AllRoutesError {}
44
/// Builder that validates configuration and assembles a [`Router`] from the
/// selected [`RouteGroup`]s.
///
/// Settings are configured with the setter methods and groups with the selector
/// methods; finish with `build` or `build_for_saas`.
pub struct AllRoutesBuilder {
    /// Forwarded to every route constructor that takes an `is_production` flag.
    is_production: bool,
    /// Base URL; required by the PasswordReset, Mfa, OAuth and WellKnown groups and
    /// whenever `events_tx` is set. Also forwarded to the register routes.
    base_url: Option<String>,

    /// Forwarded to the login routes (`new()` sets 10).
    max_login_attempts: u32,
    /// Forwarded to the login routes; seconds, per the name (`new()` sets 900).
    rate_limit_window_secs: u64,
    /// Explicit provider names for the login/register routes. When `None`, the sorted
    /// keys of `oauth_provider_impls` are used instead.
    oauth_providers_list: Option<Vec<String>>,
    /// Forwarded (moved) into the login routes.
    login_overrides: Option<crate::login_routes::LoginOverrides>,

    /// Provider implementations keyed by name; required when OAuth routes are selected.
    oauth_provider_impls: Option<HashMap<String, Box<dyn OAuthProvider>>>,

    /// Issuer passed to the MFA routes; required when Mfa routes are selected.
    mfa_issuer: Option<String>,

    /// JSON Schema for extra registration fields. It is compiled only when Register
    /// routes are selected; otherwise it is ignored.
    custom_fields_schema: Option<serde_json::Value>,
    /// Forwarded to the register routes (`new()` sets `true`).
    public_registration: bool,

    /// Lifecycle event channel forwarded to the register and OAuth routes.
    events_tx: Option<LifecycleEventSender>,

    /// Route groups selected individually.
    routes: HashSet<RouteGroup>,
    /// Set by `all_routes()`; selects every group regardless of `routes`.
    all: bool,

    /// When true, the Token/UserInfo/WellKnown routes are wrapped in the CORS middleware.
    cors_enabled: bool,

    /// Attached as an `Arc<DefaultBranding>` extension on the CSRF-protected routes.
    default_branding: Option<allowthem_core::applications::BrandingConfig>,
}
81
82impl Default for AllRoutesBuilder {
83 fn default() -> Self {
84 Self::new()
85 }
86}
87
88impl AllRoutesBuilder {
89 pub fn new() -> Self {
90 Self {
91 is_production: false,
92 base_url: None,
93 max_login_attempts: 10,
94 rate_limit_window_secs: 900,
95 oauth_providers_list: None,
96 login_overrides: None,
97 oauth_provider_impls: None,
98 mfa_issuer: None,
99 custom_fields_schema: None,
100 public_registration: true,
101 events_tx: None,
102 routes: HashSet::new(),
103 all: false,
104 cors_enabled: false,
105 default_branding: None,
106 }
107 }
108
109 pub fn is_production(mut self, is_production: bool) -> Self {
112 self.is_production = is_production;
113 self
114 }
115
116 pub fn base_url(mut self, base_url: impl Into<String>) -> Self {
117 self.base_url = Some(base_url.into());
118 self
119 }
120
121 pub fn events(mut self, tx: LifecycleEventSender) -> Self {
126 self.events_tx = Some(tx);
127 self
128 }
129
130 pub fn default_branding(
135 mut self,
136 branding: allowthem_core::applications::BrandingConfig,
137 ) -> Self {
138 self.default_branding = Some(branding);
139 self
140 }
141
142 pub fn max_login_attempts(mut self, max: u32) -> Self {
145 self.max_login_attempts = max;
146 self
147 }
148
149 pub fn rate_limit_window_secs(mut self, secs: u64) -> Self {
150 self.rate_limit_window_secs = secs;
151 self
152 }
153
154 pub fn oauth_providers_list(mut self, providers: Vec<String>) -> Self {
155 self.oauth_providers_list = Some(providers);
156 self
157 }
158
159 pub fn login_overrides(mut self, overrides: crate::login_routes::LoginOverrides) -> Self {
160 self.login_overrides = Some(overrides);
161 self
162 }
163
164 pub fn oauth_providers(mut self, providers: HashMap<String, Box<dyn OAuthProvider>>) -> Self {
167 self.oauth_provider_impls = Some(providers);
168 self
169 }
170
171 pub fn mfa_issuer(mut self, issuer: impl Into<String>) -> Self {
174 self.mfa_issuer = Some(issuer.into());
175 self
176 }
177
178 pub fn custom_fields_schema(mut self, schema: serde_json::Value) -> Self {
181 self.custom_fields_schema = Some(schema);
182 self
183 }
184
185 pub fn public_registration(mut self, enabled: bool) -> Self {
186 self.public_registration = enabled;
187 self
188 }
189
190 pub fn login(mut self) -> Self {
193 self.routes.insert(RouteGroup::Login);
194 self
195 }
196
197 pub fn register(mut self) -> Self {
198 self.routes.insert(RouteGroup::Register);
199 self
200 }
201
202 pub fn logout(mut self) -> Self {
203 self.routes.insert(RouteGroup::Logout);
204 self
205 }
206
207 pub fn settings(mut self) -> Self {
208 self.routes.insert(RouteGroup::Settings);
209 self
210 }
211
212 pub fn consent(mut self) -> Self {
213 self.routes.insert(RouteGroup::Consent);
214 self
215 }
216
217 pub fn password_reset(mut self) -> Self {
218 self.routes.insert(RouteGroup::PasswordReset);
219 self
220 }
221
222 pub fn mfa(mut self) -> Self {
223 self.routes.insert(RouteGroup::Mfa);
224 self
225 }
226
227 pub fn oauth(mut self) -> Self {
228 self.routes.insert(RouteGroup::OAuth);
229 self
230 }
231
232 pub fn token(mut self) -> Self {
233 self.routes.insert(RouteGroup::Token);
234 self
235 }
236
237 pub fn userinfo(mut self) -> Self {
238 self.routes.insert(RouteGroup::UserInfo);
239 self
240 }
241
242 pub fn well_known(mut self) -> Self {
243 self.routes.insert(RouteGroup::WellKnown);
244 self
245 }
246
247 pub fn all_routes(mut self) -> Self {
248 self.all = true;
249 self
250 }
251
252 pub fn cors(mut self) -> Self {
257 self.cors_enabled = true;
258 self
259 }
260
261 fn selected(&self, group: RouteGroup) -> bool {
264 self.all || self.routes.contains(&group)
265 }
266
267 fn validate(&self) -> Result<(), AllRoutesError> {
268 if !self.all && self.routes.is_empty() {
269 return Err(AllRoutesError::NoRoutesSelected);
270 }
271
272 let needs_base_url = self.selected(RouteGroup::PasswordReset)
273 || self.selected(RouteGroup::Mfa)
274 || self.selected(RouteGroup::OAuth)
275 || self.selected(RouteGroup::WellKnown);
276
277 if needs_base_url && self.base_url.is_none() {
278 return Err(AllRoutesError::MissingConfig(
279 "base_url required by selected route groups".into(),
280 ));
281 }
282
283 if self.events_tx.is_some() && self.base_url.is_none() {
284 return Err(AllRoutesError::MissingConfig(
285 "base_url required when events channel is configured".into(),
286 ));
287 }
288
289 if self.selected(RouteGroup::OAuth) && self.oauth_provider_impls.is_none() {
290 return Err(AllRoutesError::MissingConfig(
291 "oauth_providers required when oauth routes are selected".into(),
292 ));
293 }
294
295 if self.selected(RouteGroup::Mfa) && self.mfa_issuer.is_none() {
296 return Err(AllRoutesError::MissingConfig(
297 "mfa_issuer required when mfa routes are selected".into(),
298 ));
299 }
300
301 Ok(())
302 }
303
304 fn build_inner(mut self) -> Result<Router<()>, AllRoutesError> {
305 self.validate()?;
306
307 let is_production = self.is_production;
310
311 let oauth_providers_list = self
315 .oauth_providers_list
316 .take()
317 .unwrap_or_else(|| match &self.oauth_provider_impls {
318 Some(map) => {
319 let mut names: Vec<String> = map.keys().cloned().collect();
320 names.sort();
321 names
322 }
323 None => Vec::new(),
324 });
325
326 let mut csrf_protected: Router<()> = Router::new();
329
330 if self.selected(RouteGroup::Login) {
331 csrf_protected = csrf_protected.merge(crate::login_routes::login_routes(
332 is_production,
333 self.max_login_attempts,
334 self.rate_limit_window_secs,
335 oauth_providers_list.clone(),
336 self.login_overrides.take(),
337 ));
338 }
339
340 if self.selected(RouteGroup::Register) {
341 let custom_schema = if let Some(schema) = self.custom_fields_schema.take() {
342 crate::custom_fields::validate_custom_schema(&schema)
343 .map_err(AllRoutesError::InvalidSchema)?;
344 let validator = jsonschema::validator_for(&schema)
345 .map_err(|e| AllRoutesError::InvalidSchema(e.to_string()))?;
346 let fields = crate::custom_fields::extract_field_descriptors(&schema);
347 Some(crate::custom_fields::CustomSchemaConfig {
348 schema,
349 validator,
350 fields,
351 })
352 } else {
353 None
354 };
355 csrf_protected = csrf_protected.merge(crate::register_routes::register_routes(
356 is_production,
357 custom_schema,
358 self.events_tx.clone(),
359 self.base_url.clone(),
360 oauth_providers_list.clone(),
361 self.public_registration,
362 ));
363 }
364
365 if self.selected(RouteGroup::Logout) {
366 csrf_protected = csrf_protected.merge(crate::logout_routes::logout_routes());
367 }
368
369 if self.selected(RouteGroup::Settings) {
370 csrf_protected =
371 csrf_protected.merge(crate::settings_routes::settings_routes(is_production));
372 }
373
374 if self.selected(RouteGroup::Consent) {
375 csrf_protected =
376 csrf_protected.merge(crate::consent_routes::consent_routes(is_production));
377 }
378
379 if self.selected(RouteGroup::PasswordReset) {
380 csrf_protected = csrf_protected.merge(
381 crate::password_reset_page_routes::password_reset_page_routes(is_production),
382 );
383 }
384
385 if self.selected(RouteGroup::Mfa) {
386 let base_url = self.base_url.clone().expect("validated above");
387 csrf_protected = csrf_protected.merge(crate::mfa_page_routes::mfa_setup_routes(
388 is_production,
389 base_url,
390 ));
391 }
392
393 let mut non_csrf: Router<()> = Router::new();
396
397 if self.selected(RouteGroup::Mfa) {
398 non_csrf = non_csrf.merge(crate::mfa_page_routes::mfa_challenge_routes(is_production));
399 let issuer = self.mfa_issuer.take().expect("validated above");
400 non_csrf = non_csrf.merge(crate::mfa_routes::mfa_routes(issuer));
401 }
402
403 if self.selected(RouteGroup::OAuth) {
404 let providers = self.oauth_provider_impls.take().expect("validated above");
405 let base_url = self.base_url.clone().expect("validated above");
406 non_csrf = non_csrf.merge(crate::oauth_routes::oauth_routes(
407 providers,
408 base_url,
409 self.events_tx.clone(),
410 ));
411 }
412
413 let mut oidc: Router<()> = Router::new();
416
417 if self.selected(RouteGroup::Token) {
418 oidc = oidc.merge(crate::token_route::token_route());
419 }
420
421 if self.selected(RouteGroup::UserInfo) {
422 oidc = oidc.merge(crate::userinfo_route::userinfo_route());
423 }
424
425 if self.selected(RouteGroup::WellKnown) {
426 let base_url = self.base_url.clone().expect("validated above");
427 oidc = oidc.merge(crate::well_known_routes::well_known_routes(base_url));
428 }
429
430 let oidc_final: Router<()> = if self.cors_enabled {
433 oidc.layer(axum::middleware::from_fn(crate::cors::cors_middleware))
434 } else {
435 oidc
436 };
437
438 non_csrf = non_csrf.merge(oidc_final);
439
440 if self.selected(RouteGroup::PasswordReset) {
441 non_csrf = non_csrf.merge(crate::password_reset_routes::password_reset_routes());
442 }
443
444 let default_branding = self.default_branding.take();
445
446 let mut csrf_protected =
450 csrf_protected.layer(axum::middleware::from_fn(crate::csrf::csrf_middleware));
451
452 if let Some(branding) = default_branding {
453 let default = Arc::new(crate::branding::DefaultBranding(branding));
454 csrf_protected = csrf_protected.layer(axum::Extension(default));
455 }
456
457 Ok(csrf_protected
458 .merge(non_csrf)
459 .nest("/static/wavefunk", wavefunk_ui::axum::asset_router())
460 .merge(crate::static_routes::router()))
461 }
462
463 pub fn build(self, ath: &AllowThem) -> Result<Router<()>, AllRoutesError> {
466 let inner = self.build_inner()?;
467 Ok(inner.layer(axum::middleware::from_fn_with_state(
468 ath.clone(),
469 crate::cors::inject_ath_into_extensions,
470 )))
471 }
472
473 pub fn build_for_saas(self) -> Result<Router<()>, AllRoutesError> {
476 self.build_inner()
477 }
478}
479
#[cfg(test)]
mod tests {
    use super::*;
    use allowthem_core::AllowThemBuilder;
    use axum::body::{self, Body};
    use axum::http::{Request, StatusCode};
    use tower::ServiceExt;

    /// In-memory `AllowThem` handle with default settings.
    async fn memory_ath() -> allowthem_core::AllowThem {
        AllowThemBuilder::new("sqlite::memory:")
            .build()
            .await
            .unwrap()
    }

    /// In-memory `AllowThem` handle with a fixed CSRF key, for tests that mount
    /// CSRF-protected routes.
    async fn memory_ath_with_csrf_key() -> allowthem_core::AllowThem {
        AllowThemBuilder::new("sqlite::memory:")
            .csrf_key(*b"test-csrf-key-for-server-tests!!")
            .build()
            .await
            .unwrap()
    }

    #[tokio::test]
    async fn build_fails_no_routes_selected() {
        let ath = memory_ath().await;
        let outcome = AllRoutesBuilder::new().build(&ath);
        assert!(matches!(outcome, Err(AllRoutesError::NoRoutesSelected)));
    }

    #[tokio::test]
    async fn build_fails_oauth_without_providers() {
        let ath = memory_ath().await;
        let outcome = AllRoutesBuilder::new()
            .base_url("http://localhost")
            .oauth()
            .build(&ath);
        assert!(matches!(outcome, Err(AllRoutesError::MissingConfig(_))));
    }

    #[tokio::test]
    async fn build_fails_mfa_without_issuer() {
        let ath = memory_ath().await;
        let outcome = AllRoutesBuilder::new()
            .base_url("http://localhost")
            .mfa()
            .build(&ath);
        assert!(matches!(outcome, Err(AllRoutesError::MissingConfig(_))));
    }

    #[tokio::test]
    async fn build_fails_events_without_base_url() {
        let ath = memory_ath_with_csrf_key().await;
        let (tx, _rx) = tokio::sync::mpsc::unbounded_channel();
        let outcome = AllRoutesBuilder::new().events(tx).register().build(&ath);
        match outcome {
            Err(AllRoutesError::MissingConfig(msg)) => assert!(msg.contains("base_url")),
            other => panic!("expected MissingConfig, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn build_fails_missing_base_url() {
        let ath = memory_ath().await;
        let outcome = AllRoutesBuilder::new().well_known().build(&ath);
        assert!(matches!(outcome, Err(AllRoutesError::MissingConfig(_))));
    }

    #[tokio::test]
    async fn build_with_invalid_schema_returns_error() {
        let ath = memory_ath_with_csrf_key().await;
        let schema = serde_json::json!({"type": "array"});
        let outcome = AllRoutesBuilder::new()
            .register()
            .custom_fields_schema(schema)
            .build(&ath);
        assert!(matches!(outcome, Err(AllRoutesError::InvalidSchema(_))));
    }

    #[test]
    fn default_branding_setter_stores_value() {
        use allowthem_core::applications::BrandingConfig;
        let builder = AllRoutesBuilder::new().default_branding(BrandingConfig::new("Fixture Co"));
        let stored_name = builder
            .default_branding
            .as_ref()
            .map(|b| b.application_name.as_str());
        assert_eq!(stored_name, Some("Fixture Co"));
    }

    #[test]
    fn default_branding_absent_by_default() {
        assert!(AllRoutesBuilder::new().default_branding.is_none());
    }

    #[tokio::test]
    async fn build_succeeds_for_simple_routes() {
        let ath = memory_ath_with_csrf_key().await;
        let outcome = AllRoutesBuilder::new()
            .login()
            .register()
            .logout()
            .settings()
            .consent()
            .token()
            .userinfo()
            .build(&ath);
        assert!(outcome.is_ok());
    }

    #[tokio::test]
    async fn built_routes_serve_wavefunk_ui_assets() {
        let ath = memory_ath_with_csrf_key().await;
        let app = AllRoutesBuilder::new()
            .login()
            .build(&ath)
            .expect("routes should build");

        let cases = [
            (
                "/static/wavefunk/css/wavefunk.css",
                "text/css; charset=utf-8",
                "@import url(\"./04-components.css\")",
            ),
            (
                "/static/wavefunk/js/wavefunk.js",
                "text/javascript; charset=utf-8",
                "data-wf-copy",
            ),
        ];

        for (path, content_type, expected_text) in cases {
            let request = Request::builder().uri(path).body(Body::empty()).unwrap();
            let response = app.clone().oneshot(request).await.unwrap();
            assert_eq!(response.status(), StatusCode::OK, "{path} should be served");
            assert_eq!(
                response.headers().get("content-type").unwrap(),
                content_type
            );
            let bytes = body::to_bytes(response.into_body(), usize::MAX)
                .await
                .unwrap();
            let text = std::str::from_utf8(&bytes).expect("asset should be UTF-8");
            assert!(
                text.contains(expected_text),
                "{path} missing expected marker {expected_text}"
            );
        }
    }
}