Skip to main content

allowthem_server/
all_routes.rs

1use std::collections::{HashMap, HashSet};
2use std::fmt;
3use std::sync::Arc;
4
5use axum::Router;
6
7use allowthem_core::{AllowThem, LifecycleEventSender, OAuthProvider};
8
/// One named slice of the allowthem HTTP surface.
///
/// `AllRoutesBuilder` mounts only the groups that were selected through its
/// per-group selector methods, or every group when `all_routes()` is used.
#[derive(Debug, Clone, Copy)]
#[derive(PartialEq, Eq, Hash)]
pub enum RouteGroup {
    /// Login pages (`login_routes`).
    Login,
    /// Registration pages (`register_routes`).
    Register,
    /// Logout route (`logout_routes`).
    Logout,
    /// Account settings pages (`settings_routes`).
    Settings,
    /// Consent pages (`consent_routes`).
    Consent,
    /// Password-reset pages plus `password_reset_routes`.
    PasswordReset,
    /// MFA setup and challenge pages plus `mfa_routes`.
    Mfa,
    /// OAuth provider routes (`oauth_routes`).
    OAuth,
    /// The `/oauth/token` endpoint.
    Token,
    /// The `/oauth/userinfo` endpoint.
    UserInfo,
    /// The `/.well-known/*` routes.
    WellKnown,
}
24
/// Reasons [`AllRoutesBuilder::build`] can refuse to assemble a router.
#[derive(Debug)]
pub enum AllRoutesError {
    /// No route group was selected and `all_routes()` was not called.
    NoRoutesSelected,
    /// A selected route group (or the events channel) lacks a required
    /// setting; the message names what is missing.
    MissingConfig(String),
    /// The custom fields schema was rejected; the message carries the reason.
    InvalidSchema(String),
}

impl fmt::Display for AllRoutesError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match *self {
            AllRoutesError::NoRoutesSelected => write!(f, "no route groups selected"),
            AllRoutesError::MissingConfig(ref msg) => write!(f, "missing config: {msg}"),
            AllRoutesError::InvalidSchema(ref msg) => {
                write!(f, "invalid custom fields schema: {msg}")
            }
        }
    }
}

// None of the variants wrap an underlying error, so the default `source()`
// (returning `None`) is correct.
impl std::error::Error for AllRoutesError {}
44
45/// Builder that assembles allowthem route groups into a single
46/// [`Router<()>`] with CSRF middleware applied to the correct subset.
47pub struct AllRoutesBuilder {
48    // Shared config
49    is_production: bool,
50    base_url: Option<String>,
51
52    // Login-specific
53    max_login_attempts: u32,
54    rate_limit_window_secs: u64,
55    oauth_providers_list: Option<Vec<String>>,
56    login_overrides: Option<crate::login_routes::LoginOverrides>,
57
58    // OAuth-specific
59    oauth_provider_impls: Option<HashMap<String, Box<dyn OAuthProvider>>>,
60
61    // MFA-specific
62    mfa_issuer: Option<String>,
63
64    // Register-specific
65    custom_fields_schema: Option<serde_json::Value>,
66    public_registration: bool,
67
68    // Event publishing (optional)
69    events_tx: Option<LifecycleEventSender>,
70
71    // Route selection
72    routes: HashSet<RouteGroup>,
73    all: bool,
74
75    // CORS for OIDC endpoints
76    cors_enabled: bool,
77
78    // Embedder fallback branding
79    default_branding: Option<allowthem_core::applications::BrandingConfig>,
80}
81
82impl Default for AllRoutesBuilder {
83    fn default() -> Self {
84        Self::new()
85    }
86}
87
88impl AllRoutesBuilder {
89    pub fn new() -> Self {
90        Self {
91            is_production: false,
92            base_url: None,
93            max_login_attempts: 10,
94            rate_limit_window_secs: 900,
95            oauth_providers_list: None,
96            login_overrides: None,
97            oauth_provider_impls: None,
98            mfa_issuer: None,
99            custom_fields_schema: None,
100            public_registration: true,
101            events_tx: None,
102            routes: HashSet::new(),
103            all: false,
104            cors_enabled: false,
105            default_branding: None,
106        }
107    }
108
109    // --- Shared config ---
110
111    pub fn is_production(mut self, is_production: bool) -> Self {
112        self.is_production = is_production;
113        self
114    }
115
116    pub fn base_url(mut self, base_url: impl Into<String>) -> Self {
117        self.base_url = Some(base_url.into());
118        self
119    }
120
121    /// Attach a channel sender that receives lifecycle events (register, etc.).
122    ///
123    /// See `docs/superpowers/specs/2026-04-20-lifecycle-events-design.md` for
124    /// the delivery contract. Called at most once; subsequent calls overwrite.
125    pub fn events(mut self, tx: LifecycleEventSender) -> Self {
126        self.events_tx = Some(tx);
127        self
128    }
129
130    /// Set a fallback branding used for every pre-auth page when no per-client
131    /// lookup matches (no `client_id` query param, or no row with that id).
132    /// Embedders typically call this once at startup with their own app name,
133    /// accent, and splash metadata.
134    pub fn default_branding(
135        mut self,
136        branding: allowthem_core::applications::BrandingConfig,
137    ) -> Self {
138        self.default_branding = Some(branding);
139        self
140    }
141
142    // --- Login config ---
143
144    pub fn max_login_attempts(mut self, max: u32) -> Self {
145        self.max_login_attempts = max;
146        self
147    }
148
149    pub fn rate_limit_window_secs(mut self, secs: u64) -> Self {
150        self.rate_limit_window_secs = secs;
151        self
152    }
153
154    pub fn oauth_providers_list(mut self, providers: Vec<String>) -> Self {
155        self.oauth_providers_list = Some(providers);
156        self
157    }
158
159    pub fn login_overrides(mut self, overrides: crate::login_routes::LoginOverrides) -> Self {
160        self.login_overrides = Some(overrides);
161        self
162    }
163
164    // --- OAuth config ---
165
166    pub fn oauth_providers(mut self, providers: HashMap<String, Box<dyn OAuthProvider>>) -> Self {
167        self.oauth_provider_impls = Some(providers);
168        self
169    }
170
171    // --- MFA config ---
172
173    pub fn mfa_issuer(mut self, issuer: impl Into<String>) -> Self {
174        self.mfa_issuer = Some(issuer.into());
175        self
176    }
177
178    // --- Register config ---
179
180    pub fn custom_fields_schema(mut self, schema: serde_json::Value) -> Self {
181        self.custom_fields_schema = Some(schema);
182        self
183    }
184
185    pub fn public_registration(mut self, enabled: bool) -> Self {
186        self.public_registration = enabled;
187        self
188    }
189
190    // --- Route selectors ---
191
192    pub fn login(mut self) -> Self {
193        self.routes.insert(RouteGroup::Login);
194        self
195    }
196
197    pub fn register(mut self) -> Self {
198        self.routes.insert(RouteGroup::Register);
199        self
200    }
201
202    pub fn logout(mut self) -> Self {
203        self.routes.insert(RouteGroup::Logout);
204        self
205    }
206
207    pub fn settings(mut self) -> Self {
208        self.routes.insert(RouteGroup::Settings);
209        self
210    }
211
212    pub fn consent(mut self) -> Self {
213        self.routes.insert(RouteGroup::Consent);
214        self
215    }
216
217    pub fn password_reset(mut self) -> Self {
218        self.routes.insert(RouteGroup::PasswordReset);
219        self
220    }
221
222    pub fn mfa(mut self) -> Self {
223        self.routes.insert(RouteGroup::Mfa);
224        self
225    }
226
227    pub fn oauth(mut self) -> Self {
228        self.routes.insert(RouteGroup::OAuth);
229        self
230    }
231
232    pub fn token(mut self) -> Self {
233        self.routes.insert(RouteGroup::Token);
234        self
235    }
236
237    pub fn userinfo(mut self) -> Self {
238        self.routes.insert(RouteGroup::UserInfo);
239        self
240    }
241
242    pub fn well_known(mut self) -> Self {
243        self.routes.insert(RouteGroup::WellKnown);
244        self
245    }
246
247    pub fn all_routes(mut self) -> Self {
248        self.all = true;
249        self
250    }
251
252    /// Enable dynamic CORS for OIDC endpoints (`/oauth/token`, `/oauth/userinfo`,
253    /// `/.well-known/*`). Allowed origins are derived per-request from active
254    /// application redirect URIs. Has no effect unless Token, UserInfo, or
255    /// WellKnown routes are also selected.
256    pub fn cors(mut self) -> Self {
257        self.cors_enabled = true;
258        self
259    }
260
261    // --- Build ---
262
263    fn selected(&self, group: RouteGroup) -> bool {
264        self.all || self.routes.contains(&group)
265    }
266
267    fn validate(&self) -> Result<(), AllRoutesError> {
268        if !self.all && self.routes.is_empty() {
269            return Err(AllRoutesError::NoRoutesSelected);
270        }
271
272        let needs_base_url = self.selected(RouteGroup::PasswordReset)
273            || self.selected(RouteGroup::Mfa)
274            || self.selected(RouteGroup::OAuth)
275            || self.selected(RouteGroup::WellKnown);
276
277        if needs_base_url && self.base_url.is_none() {
278            return Err(AllRoutesError::MissingConfig(
279                "base_url required by selected route groups".into(),
280            ));
281        }
282
283        if self.events_tx.is_some() && self.base_url.is_none() {
284            return Err(AllRoutesError::MissingConfig(
285                "base_url required when events channel is configured".into(),
286            ));
287        }
288
289        if self.selected(RouteGroup::OAuth) && self.oauth_provider_impls.is_none() {
290            return Err(AllRoutesError::MissingConfig(
291                "oauth_providers required when oauth routes are selected".into(),
292            ));
293        }
294
295        if self.selected(RouteGroup::Mfa) && self.mfa_issuer.is_none() {
296            return Err(AllRoutesError::MissingConfig(
297                "mfa_issuer required when mfa routes are selected".into(),
298            ));
299        }
300
301        Ok(())
302    }
303
304    fn build_inner(mut self) -> Result<Router<()>, AllRoutesError> {
305        self.validate()?;
306
307        // --- Resolve defaults ---
308
309        let is_production = self.is_production;
310
311        // Derive oauth_providers_list from the provider map keys when not
312        // explicitly set. This avoids requiring the caller to duplicate the
313        // provider names.
314        let oauth_providers_list = self
315            .oauth_providers_list
316            .take()
317            .unwrap_or_else(|| match &self.oauth_provider_impls {
318                Some(map) => {
319                    let mut names: Vec<String> = map.keys().cloned().collect();
320                    names.sort();
321                    names
322                }
323                None => Vec::new(),
324            });
325
326        // --- CSRF-protected routes (browser routes) ---
327
328        let mut csrf_protected: Router<()> = Router::new();
329
330        if self.selected(RouteGroup::Login) {
331            csrf_protected = csrf_protected.merge(crate::login_routes::login_routes(
332                is_production,
333                self.max_login_attempts,
334                self.rate_limit_window_secs,
335                oauth_providers_list.clone(),
336                self.login_overrides.take(),
337            ));
338        }
339
340        if self.selected(RouteGroup::Register) {
341            let custom_schema = if let Some(schema) = self.custom_fields_schema.take() {
342                crate::custom_fields::validate_custom_schema(&schema)
343                    .map_err(AllRoutesError::InvalidSchema)?;
344                let validator = jsonschema::validator_for(&schema)
345                    .map_err(|e| AllRoutesError::InvalidSchema(e.to_string()))?;
346                let fields = crate::custom_fields::extract_field_descriptors(&schema);
347                Some(crate::custom_fields::CustomSchemaConfig {
348                    schema,
349                    validator,
350                    fields,
351                })
352            } else {
353                None
354            };
355            csrf_protected = csrf_protected.merge(crate::register_routes::register_routes(
356                is_production,
357                custom_schema,
358                self.events_tx.clone(),
359                self.base_url.clone(),
360                oauth_providers_list.clone(),
361                self.public_registration,
362            ));
363        }
364
365        if self.selected(RouteGroup::Logout) {
366            csrf_protected = csrf_protected.merge(crate::logout_routes::logout_routes());
367        }
368
369        if self.selected(RouteGroup::Settings) {
370            csrf_protected =
371                csrf_protected.merge(crate::settings_routes::settings_routes(is_production));
372        }
373
374        if self.selected(RouteGroup::Consent) {
375            csrf_protected =
376                csrf_protected.merge(crate::consent_routes::consent_routes(is_production));
377        }
378
379        if self.selected(RouteGroup::PasswordReset) {
380            csrf_protected = csrf_protected.merge(
381                crate::password_reset_page_routes::password_reset_page_routes(is_production),
382            );
383        }
384
385        if self.selected(RouteGroup::Mfa) {
386            let base_url = self.base_url.clone().expect("validated above");
387            csrf_protected = csrf_protected.merge(crate::mfa_page_routes::mfa_setup_routes(
388                is_production,
389                base_url,
390            ));
391        }
392
393        // --- Non-CSRF routes ---
394
395        let mut non_csrf: Router<()> = Router::new();
396
397        if self.selected(RouteGroup::Mfa) {
398            non_csrf = non_csrf.merge(crate::mfa_page_routes::mfa_challenge_routes(is_production));
399            let issuer = self.mfa_issuer.take().expect("validated above");
400            non_csrf = non_csrf.merge(crate::mfa_routes::mfa_routes(issuer));
401        }
402
403        if self.selected(RouteGroup::OAuth) {
404            let providers = self.oauth_provider_impls.take().expect("validated above");
405            let base_url = self.base_url.clone().expect("validated above");
406            non_csrf = non_csrf.merge(crate::oauth_routes::oauth_routes(
407                providers,
408                base_url,
409                self.events_tx.clone(),
410            ));
411        }
412
413        // --- OIDC sub-router (CORS-eligible routes) ---
414
415        let mut oidc: Router<()> = Router::new();
416
417        if self.selected(RouteGroup::Token) {
418            oidc = oidc.merge(crate::token_route::token_route());
419        }
420
421        if self.selected(RouteGroup::UserInfo) {
422            oidc = oidc.merge(crate::userinfo_route::userinfo_route());
423        }
424
425        if self.selected(RouteGroup::WellKnown) {
426            let base_url = self.base_url.clone().expect("validated above");
427            oidc = oidc.merge(crate::well_known_routes::well_known_routes(base_url));
428        }
429
430        // cors_middleware reads AllowThem from extensions; the inject shim is
431        // applied by the caller (build/build_for_saas) at the appropriate scope.
432        let oidc_final: Router<()> = if self.cors_enabled {
433            oidc.layer(axum::middleware::from_fn(crate::cors::cors_middleware))
434        } else {
435            oidc
436        };
437
438        non_csrf = non_csrf.merge(oidc_final);
439
440        if self.selected(RouteGroup::PasswordReset) {
441            non_csrf = non_csrf.merge(crate::password_reset_routes::password_reset_routes());
442        }
443
444        let default_branding = self.default_branding.take();
445
446        // Apply CSRF middleware to browser routes, then merge non-CSRF routes.
447        // Both csrf_middleware and all handlers read AllowThem from extensions.
448        // Static assets are merged unconditionally and bypass all middleware.
449        let mut csrf_protected =
450            csrf_protected.layer(axum::middleware::from_fn(crate::csrf::csrf_middleware));
451
452        if let Some(branding) = default_branding {
453            let default = Arc::new(crate::branding::DefaultBranding(branding));
454            csrf_protected = csrf_protected.layer(axum::Extension(default));
455        }
456
457        Ok(csrf_protected
458            .merge(non_csrf)
459            .nest("/static/wavefunk", wavefunk_ui::axum::asset_router())
460            .merge(crate::static_routes::router()))
461    }
462
463    /// Build routes for standalone mode. Wraps `build_inner` with the inject
464    /// shim that bridges `State<AllowThem>` into request extensions.
465    pub fn build(self, ath: &AllowThem) -> Result<Router<()>, AllRoutesError> {
466        let inner = self.build_inner()?;
467        Ok(inner.layer(axum::middleware::from_fn_with_state(
468            ath.clone(),
469            crate::cors::inject_ath_into_extensions,
470        )))
471    }
472
473    /// Build routes for SaaS mode. The tenant router injects AllowThem into
474    /// extensions before dispatching, so no inject shim is added here.
475    pub fn build_for_saas(self) -> Result<Router<()>, AllRoutesError> {
476        self.build_inner()
477    }
478}
479
#[cfg(test)]
mod tests {
    use super::*;
    use allowthem_core::AllowThemBuilder;
    use axum::body::{self, Body};
    use axum::http::{Request, StatusCode};
    use tower::ServiceExt;

    /// 32-byte CSRF key for tests that mount CSRF-protected routes.
    const TEST_CSRF_KEY: [u8; 32] = *b"test-csrf-key-for-server-tests!!";

    /// In-memory SQLite `AllowThem` without a CSRF key.
    async fn memory_ath() -> allowthem_core::AllowThem {
        AllowThemBuilder::new("sqlite::memory:")
            .build()
            .await
            .unwrap()
    }

    /// In-memory SQLite `AllowThem` configured with `TEST_CSRF_KEY`.
    async fn memory_ath_with_csrf() -> allowthem_core::AllowThem {
        AllowThemBuilder::new("sqlite::memory:")
            .csrf_key(TEST_CSRF_KEY)
            .build()
            .await
            .unwrap()
    }

    #[tokio::test]
    async fn build_fails_no_routes_selected() {
        let ath = memory_ath().await;
        let outcome = AllRoutesBuilder::new().build(&ath);
        assert!(matches!(outcome, Err(AllRoutesError::NoRoutesSelected)));
    }

    #[tokio::test]
    async fn build_fails_oauth_without_providers() {
        let ath = memory_ath().await;
        let outcome = AllRoutesBuilder::new()
            .base_url("http://localhost")
            .oauth()
            .build(&ath);
        assert!(matches!(outcome, Err(AllRoutesError::MissingConfig(_))));
    }

    #[tokio::test]
    async fn build_fails_mfa_without_issuer() {
        let ath = memory_ath().await;
        let outcome = AllRoutesBuilder::new()
            .base_url("http://localhost")
            .mfa()
            .build(&ath);
        assert!(matches!(outcome, Err(AllRoutesError::MissingConfig(_))));
    }

    #[tokio::test]
    async fn build_fails_events_without_base_url() {
        let ath = memory_ath_with_csrf().await;
        // Keep the receiver alive for the duration of the test.
        let (tx, _rx) = tokio::sync::mpsc::unbounded_channel();
        let outcome = AllRoutesBuilder::new().events(tx).register().build(&ath);
        if let Err(AllRoutesError::MissingConfig(msg)) = &outcome {
            assert!(msg.contains("base_url"));
        } else {
            panic!("expected MissingConfig, got {outcome:?}");
        }
    }

    #[tokio::test]
    async fn build_fails_missing_base_url() {
        let ath = memory_ath().await;
        let outcome = AllRoutesBuilder::new().well_known().build(&ath);
        assert!(matches!(outcome, Err(AllRoutesError::MissingConfig(_))));
    }

    #[tokio::test]
    async fn build_with_invalid_schema_returns_error() {
        let ath = memory_ath_with_csrf().await;
        // A schema with type "array" is not a valid custom fields schema
        let array_schema = serde_json::json!({"type": "array"});
        let outcome = AllRoutesBuilder::new()
            .register()
            .custom_fields_schema(array_schema)
            .build(&ath);
        assert!(matches!(outcome, Err(AllRoutesError::InvalidSchema(_))));
    }

    #[test]
    fn default_branding_setter_stores_value() {
        use allowthem_core::applications::BrandingConfig;
        let builder = AllRoutesBuilder::new().default_branding(BrandingConfig::new("Fixture Co"));
        let stored_name = builder
            .default_branding
            .as_ref()
            .map(|b| b.application_name.as_str());
        assert_eq!(stored_name, Some("Fixture Co"));
    }

    #[test]
    fn default_branding_absent_by_default() {
        assert!(AllRoutesBuilder::new().default_branding.is_none());
    }

    #[tokio::test]
    async fn build_succeeds_for_simple_routes() {
        let ath = memory_ath_with_csrf().await;
        // Routes that require no extra config beyond defaults
        let outcome = AllRoutesBuilder::new()
            .login()
            .register()
            .logout()
            .settings()
            .consent()
            .token()
            .userinfo()
            .build(&ath);
        assert!(outcome.is_ok());
    }

    #[tokio::test]
    async fn built_routes_serve_wavefunk_ui_assets() {
        let ath = memory_ath_with_csrf().await;
        let app = AllRoutesBuilder::new()
            .login()
            .build(&ath)
            .expect("routes should build");

        let cases = [
            (
                "/static/wavefunk/css/wavefunk.css",
                "text/css; charset=utf-8",
                "@import url(\"./04-components.css\")",
            ),
            (
                "/static/wavefunk/js/wavefunk.js",
                "text/javascript; charset=utf-8",
                "data-wf-copy",
            ),
        ];

        for (path, content_type, expected_text) in cases {
            let request = Request::builder().uri(path).body(Body::empty()).unwrap();
            let response = app.clone().oneshot(request).await.unwrap();
            assert_eq!(response.status(), StatusCode::OK, "{path} should be served");
            assert_eq!(
                response.headers().get("content-type").unwrap(),
                content_type
            );
            let bytes = body::to_bytes(response.into_body(), usize::MAX)
                .await
                .unwrap();
            let text = std::str::from_utf8(&bytes).expect("asset should be UTF-8");
            assert!(
                text.contains(expected_text),
                "{path} missing expected marker {expected_text}"
            );
        }
    }
}
644                "{path} missing expected marker {expected_text}"
645            );
646        }
647    }
648}