Skip to main content

allowthem_server/
all_routes.rs

1use std::collections::{HashMap, HashSet};
2use std::fmt;
3use std::sync::Arc;
4
5use axum::Router;
6use minijinja::Environment;
7
8use allowthem_core::{AllowThem, LifecycleEventSender, OAuthProvider};
9
10use crate::browser_templates::build_default_browser_env;
11
/// A named bundle of routes that [`AllRoutesBuilder`] can mount on request.
///
/// Groups are picked one at a time through the builder's selector methods
/// (`login()`, `token()`, ...) or all together through `all_routes()`.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum RouteGroup {
    /// Login pages (rate-limited, CSRF-protected).
    Login,
    /// Registration pages, optionally with a custom-fields schema.
    Register,
    /// Logout route.
    Logout,
    /// Settings pages.
    Settings,
    /// Consent pages.
    Consent,
    /// Password-reset pages plus their non-CSRF companion routes.
    /// Requires `base_url`.
    PasswordReset,
    /// MFA setup and challenge pages plus the MFA routes.
    /// Requires `base_url` and `mfa_issuer`.
    Mfa,
    /// OAuth provider routes. Requires `base_url` and provider implementations.
    OAuth,
    /// The `/oauth/token` endpoint (CORS-eligible).
    Token,
    /// The `/oauth/userinfo` endpoint (CORS-eligible).
    UserInfo,
    /// The `/.well-known/*` routes (CORS-eligible). Requires `base_url`.
    WellKnown,
}
27
/// Reasons [`AllRoutesBuilder::build`] (and `build_for_saas`) can refuse to
/// produce a router.
#[derive(Debug)]
pub enum AllRoutesError {
    /// Neither `all_routes()` nor any individual route group was selected.
    NoRoutesSelected,
    /// A selected route group (or the events channel) needs a setting that was
    /// never provided; the payload describes what is missing.
    MissingConfig(String),
    /// The registration custom-fields schema was rejected; the payload carries
    /// the validation error message.
    InvalidSchema(String),
}

impl fmt::Display for AllRoutesError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        use AllRoutesError::{InvalidSchema, MissingConfig, NoRoutesSelected};

        match self {
            NoRoutesSelected => f.write_str("no route groups selected"),
            MissingConfig(detail) => write!(f, "missing config: {detail}"),
            InvalidSchema(detail) => write!(f, "invalid custom fields schema: {detail}"),
        }
    }
}

// Both payloads are already-rendered messages, so there is no `source()`.
impl std::error::Error for AllRoutesError {}
47
48/// Builder that assembles allowthem route groups into a single
49/// [`Router<()>`] with CSRF middleware applied to the correct subset.
50pub struct AllRoutesBuilder {
51    // Shared config
52    templates: Option<Arc<Environment<'static>>>,
53    is_production: bool,
54    base_url: Option<String>,
55
56    // Login-specific
57    max_login_attempts: u32,
58    rate_limit_window_secs: u64,
59    oauth_providers_list: Option<Vec<String>>,
60    login_overrides: Option<crate::login_routes::LoginOverrides>,
61
62    // OAuth-specific
63    oauth_provider_impls: Option<HashMap<String, Box<dyn OAuthProvider>>>,
64
65    // MFA-specific
66    mfa_issuer: Option<String>,
67
68    // Register-specific
69    custom_fields_schema: Option<serde_json::Value>,
70    public_registration: bool,
71
72    // Event publishing (optional)
73    events_tx: Option<LifecycleEventSender>,
74
75    // Route selection
76    routes: HashSet<RouteGroup>,
77    all: bool,
78
79    // CORS for OIDC endpoints
80    cors_enabled: bool,
81
82    // Embedder fallback branding
83    default_branding: Option<allowthem_core::applications::BrandingConfig>,
84}
85
86impl Default for AllRoutesBuilder {
87    fn default() -> Self {
88        Self::new()
89    }
90}
91
92impl AllRoutesBuilder {
93    pub fn new() -> Self {
94        Self {
95            templates: None,
96            is_production: false,
97            base_url: None,
98            max_login_attempts: 10,
99            rate_limit_window_secs: 900,
100            oauth_providers_list: None,
101            login_overrides: None,
102            oauth_provider_impls: None,
103            mfa_issuer: None,
104            custom_fields_schema: None,
105            public_registration: true,
106            events_tx: None,
107            routes: HashSet::new(),
108            all: false,
109            cors_enabled: false,
110            default_branding: None,
111        }
112    }
113
114    // --- Shared config ---
115
116    pub fn templates(mut self, templates: Arc<Environment<'static>>) -> Self {
117        self.templates = Some(templates);
118        self
119    }
120
121    pub fn is_production(mut self, is_production: bool) -> Self {
122        self.is_production = is_production;
123        self
124    }
125
126    pub fn base_url(mut self, base_url: impl Into<String>) -> Self {
127        self.base_url = Some(base_url.into());
128        self
129    }
130
131    /// Attach a channel sender that receives lifecycle events (register, etc.).
132    ///
133    /// See `docs/superpowers/specs/2026-04-20-lifecycle-events-design.md` for
134    /// the delivery contract. Called at most once; subsequent calls overwrite.
135    pub fn events(mut self, tx: LifecycleEventSender) -> Self {
136        self.events_tx = Some(tx);
137        self
138    }
139
140    /// Set a fallback branding used for every pre-auth page when no per-client
141    /// lookup matches (no `client_id` query param, or no row with that id).
142    /// Embedders typically call this once at startup with their own app name,
143    /// accent, and splash metadata.
144    pub fn default_branding(
145        mut self,
146        branding: allowthem_core::applications::BrandingConfig,
147    ) -> Self {
148        self.default_branding = Some(branding);
149        self
150    }
151
152    // --- Login config ---
153
154    pub fn max_login_attempts(mut self, max: u32) -> Self {
155        self.max_login_attempts = max;
156        self
157    }
158
159    pub fn rate_limit_window_secs(mut self, secs: u64) -> Self {
160        self.rate_limit_window_secs = secs;
161        self
162    }
163
164    pub fn oauth_providers_list(mut self, providers: Vec<String>) -> Self {
165        self.oauth_providers_list = Some(providers);
166        self
167    }
168
169    pub fn login_overrides(mut self, overrides: crate::login_routes::LoginOverrides) -> Self {
170        self.login_overrides = Some(overrides);
171        self
172    }
173
174    // --- OAuth config ---
175
176    pub fn oauth_providers(mut self, providers: HashMap<String, Box<dyn OAuthProvider>>) -> Self {
177        self.oauth_provider_impls = Some(providers);
178        self
179    }
180
181    // --- MFA config ---
182
183    pub fn mfa_issuer(mut self, issuer: impl Into<String>) -> Self {
184        self.mfa_issuer = Some(issuer.into());
185        self
186    }
187
188    // --- Register config ---
189
190    pub fn custom_fields_schema(mut self, schema: serde_json::Value) -> Self {
191        self.custom_fields_schema = Some(schema);
192        self
193    }
194
195    pub fn public_registration(mut self, enabled: bool) -> Self {
196        self.public_registration = enabled;
197        self
198    }
199
200    // --- Route selectors ---
201
202    pub fn login(mut self) -> Self {
203        self.routes.insert(RouteGroup::Login);
204        self
205    }
206
207    pub fn register(mut self) -> Self {
208        self.routes.insert(RouteGroup::Register);
209        self
210    }
211
212    pub fn logout(mut self) -> Self {
213        self.routes.insert(RouteGroup::Logout);
214        self
215    }
216
217    pub fn settings(mut self) -> Self {
218        self.routes.insert(RouteGroup::Settings);
219        self
220    }
221
222    pub fn consent(mut self) -> Self {
223        self.routes.insert(RouteGroup::Consent);
224        self
225    }
226
227    pub fn password_reset(mut self) -> Self {
228        self.routes.insert(RouteGroup::PasswordReset);
229        self
230    }
231
232    pub fn mfa(mut self) -> Self {
233        self.routes.insert(RouteGroup::Mfa);
234        self
235    }
236
237    pub fn oauth(mut self) -> Self {
238        self.routes.insert(RouteGroup::OAuth);
239        self
240    }
241
242    pub fn token(mut self) -> Self {
243        self.routes.insert(RouteGroup::Token);
244        self
245    }
246
247    pub fn userinfo(mut self) -> Self {
248        self.routes.insert(RouteGroup::UserInfo);
249        self
250    }
251
252    pub fn well_known(mut self) -> Self {
253        self.routes.insert(RouteGroup::WellKnown);
254        self
255    }
256
257    pub fn all_routes(mut self) -> Self {
258        self.all = true;
259        self
260    }
261
262    /// Enable dynamic CORS for OIDC endpoints (`/oauth/token`, `/oauth/userinfo`,
263    /// `/.well-known/*`). Allowed origins are derived per-request from active
264    /// application redirect URIs. Has no effect unless Token, UserInfo, or
265    /// WellKnown routes are also selected.
266    pub fn cors(mut self) -> Self {
267        self.cors_enabled = true;
268        self
269    }
270
271    // --- Build ---
272
273    fn selected(&self, group: RouteGroup) -> bool {
274        self.all || self.routes.contains(&group)
275    }
276
277    fn validate(&self) -> Result<(), AllRoutesError> {
278        if !self.all && self.routes.is_empty() {
279            return Err(AllRoutesError::NoRoutesSelected);
280        }
281
282        let needs_base_url = self.selected(RouteGroup::PasswordReset)
283            || self.selected(RouteGroup::Mfa)
284            || self.selected(RouteGroup::OAuth)
285            || self.selected(RouteGroup::WellKnown);
286
287        if needs_base_url && self.base_url.is_none() {
288            return Err(AllRoutesError::MissingConfig(
289                "base_url required by selected route groups".into(),
290            ));
291        }
292
293        if self.events_tx.is_some() && self.base_url.is_none() {
294            return Err(AllRoutesError::MissingConfig(
295                "base_url required when events channel is configured".into(),
296            ));
297        }
298
299        if self.selected(RouteGroup::OAuth) && self.oauth_provider_impls.is_none() {
300            return Err(AllRoutesError::MissingConfig(
301                "oauth_providers required when oauth routes are selected".into(),
302            ));
303        }
304
305        if self.selected(RouteGroup::Mfa) && self.mfa_issuer.is_none() {
306            return Err(AllRoutesError::MissingConfig(
307                "mfa_issuer required when mfa routes are selected".into(),
308            ));
309        }
310
311        Ok(())
312    }
313
314    fn build_inner(mut self) -> Result<Router<()>, AllRoutesError> {
315        self.validate()?;
316
317        // --- Resolve defaults ---
318
319        let templates = self
320            .templates
321            .take()
322            .unwrap_or_else(build_default_browser_env);
323        let is_production = self.is_production;
324
325        // Derive oauth_providers_list from the provider map keys when not
326        // explicitly set. This avoids requiring the caller to duplicate the
327        // provider names.
328        let oauth_providers_list = self
329            .oauth_providers_list
330            .take()
331            .unwrap_or_else(|| match &self.oauth_provider_impls {
332                Some(map) => {
333                    let mut names: Vec<String> = map.keys().cloned().collect();
334                    names.sort();
335                    names
336                }
337                None => Vec::new(),
338            });
339
340        // --- CSRF-protected routes (browser routes) ---
341
342        let mut csrf_protected: Router<()> = Router::new();
343
344        if self.selected(RouteGroup::Login) {
345            csrf_protected = csrf_protected.merge(crate::login_routes::login_routes(
346                templates.clone(),
347                is_production,
348                self.max_login_attempts,
349                self.rate_limit_window_secs,
350                oauth_providers_list.clone(),
351                self.login_overrides.take(),
352            ));
353        }
354
355        if self.selected(RouteGroup::Register) {
356            let custom_schema = if let Some(schema) = self.custom_fields_schema.take() {
357                crate::custom_fields::validate_custom_schema(&schema)
358                    .map_err(AllRoutesError::InvalidSchema)?;
359                let validator = jsonschema::validator_for(&schema)
360                    .map_err(|e| AllRoutesError::InvalidSchema(e.to_string()))?;
361                let fields = crate::custom_fields::extract_field_descriptors(&schema);
362                Some(crate::custom_fields::CustomSchemaConfig {
363                    schema,
364                    validator,
365                    fields,
366                })
367            } else {
368                None
369            };
370            csrf_protected = csrf_protected.merge(crate::register_routes::register_routes(
371                templates.clone(),
372                is_production,
373                custom_schema,
374                self.events_tx.clone(),
375                self.base_url.clone(),
376                oauth_providers_list.clone(),
377                self.public_registration,
378            ));
379        }
380
381        if self.selected(RouteGroup::Logout) {
382            csrf_protected = csrf_protected.merge(crate::logout_routes::logout_routes());
383        }
384
385        if self.selected(RouteGroup::Settings) {
386            csrf_protected = csrf_protected.merge(crate::settings_routes::settings_routes(
387                templates.clone(),
388                is_production,
389            ));
390        }
391
392        if self.selected(RouteGroup::Consent) {
393            csrf_protected = csrf_protected.merge(crate::consent_routes::consent_routes(
394                templates.clone(),
395                is_production,
396            ));
397        }
398
399        if self.selected(RouteGroup::PasswordReset) {
400            csrf_protected = csrf_protected.merge(
401                crate::password_reset_page_routes::password_reset_page_routes(
402                    templates.clone(),
403                    is_production,
404                ),
405            );
406        }
407
408        if self.selected(RouteGroup::Mfa) {
409            let base_url = self.base_url.clone().expect("validated above");
410            csrf_protected = csrf_protected.merge(crate::mfa_page_routes::mfa_setup_routes(
411                templates.clone(),
412                is_production,
413                base_url,
414            ));
415        }
416
417        // --- Non-CSRF routes ---
418
419        let mut non_csrf: Router<()> = Router::new();
420
421        if self.selected(RouteGroup::Mfa) {
422            non_csrf = non_csrf.merge(crate::mfa_page_routes::mfa_challenge_routes(
423                templates.clone(),
424                is_production,
425            ));
426            let issuer = self.mfa_issuer.take().expect("validated above");
427            non_csrf = non_csrf.merge(crate::mfa_routes::mfa_routes(issuer));
428        }
429
430        if self.selected(RouteGroup::OAuth) {
431            let providers = self.oauth_provider_impls.take().expect("validated above");
432            let base_url = self.base_url.clone().expect("validated above");
433            non_csrf = non_csrf.merge(crate::oauth_routes::oauth_routes(
434                providers,
435                base_url,
436                self.events_tx.clone(),
437            ));
438        }
439
440        // --- OIDC sub-router (CORS-eligible routes) ---
441
442        let mut oidc: Router<()> = Router::new();
443
444        if self.selected(RouteGroup::Token) {
445            oidc = oidc.merge(crate::token_route::token_route());
446        }
447
448        if self.selected(RouteGroup::UserInfo) {
449            oidc = oidc.merge(crate::userinfo_route::userinfo_route());
450        }
451
452        if self.selected(RouteGroup::WellKnown) {
453            let base_url = self.base_url.clone().expect("validated above");
454            oidc = oidc.merge(crate::well_known_routes::well_known_routes(base_url));
455        }
456
457        // cors_middleware reads AllowThem from extensions; the inject shim is
458        // applied by the caller (build/build_for_saas) at the appropriate scope.
459        let oidc_final: Router<()> = if self.cors_enabled {
460            oidc.layer(axum::middleware::from_fn(crate::cors::cors_middleware))
461        } else {
462            oidc
463        };
464
465        non_csrf = non_csrf.merge(oidc_final);
466
467        if self.selected(RouteGroup::PasswordReset) {
468            non_csrf = non_csrf.merge(crate::password_reset_routes::password_reset_routes());
469        }
470
471        let default_branding = self.default_branding.take();
472
473        // Apply CSRF middleware to browser routes, then merge non-CSRF routes.
474        // Both csrf_middleware and all handlers read AllowThem from extensions.
475        // Static assets are merged unconditionally and bypass all middleware.
476        let mut csrf_protected =
477            csrf_protected.layer(axum::middleware::from_fn(crate::csrf::csrf_middleware));
478
479        if let Some(branding) = default_branding {
480            let default = Arc::new(crate::branding::DefaultBranding(branding));
481            csrf_protected = csrf_protected.layer(axum::Extension(default));
482        }
483
484        Ok(csrf_protected
485            .merge(non_csrf)
486            .merge(crate::static_routes::router()))
487    }
488
489    /// Build routes for standalone mode. Wraps `build_inner` with the inject
490    /// shim that bridges `State<AllowThem>` into request extensions.
491    pub fn build(self, ath: &AllowThem) -> Result<Router<()>, AllRoutesError> {
492        let inner = self.build_inner()?;
493        Ok(inner.layer(axum::middleware::from_fn_with_state(
494            ath.clone(),
495            crate::cors::inject_ath_into_extensions,
496        )))
497    }
498
499    /// Build routes for SaaS mode. The tenant router injects AllowThem into
500    /// extensions before dispatching, so no inject shim is added here.
501    pub fn build_for_saas(self) -> Result<Router<()>, AllRoutesError> {
502        self.build_inner()
503    }
504}
505
#[cfg(test)]
mod tests {
    use super::*;
    use allowthem_core::AllowThemBuilder;

    /// In-memory `AllowThem` handle with no CSRF key configured.
    async fn memory_ath() -> AllowThem {
        AllowThemBuilder::new("sqlite::memory:")
            .build()
            .await
            .unwrap()
    }

    /// In-memory `AllowThem` handle with a fixed 32-byte CSRF key, for tests
    /// that mount CSRF-protected routes.
    async fn memory_ath_with_csrf() -> AllowThem {
        AllowThemBuilder::new("sqlite::memory:")
            .csrf_key(*b"test-csrf-key-for-server-tests!!")
            .build()
            .await
            .unwrap()
    }

    #[tokio::test]
    async fn build_fails_no_routes_selected() {
        let ath = memory_ath().await;
        let outcome = AllRoutesBuilder::new().build(&ath);
        assert!(matches!(outcome, Err(AllRoutesError::NoRoutesSelected)));
    }

    #[tokio::test]
    async fn build_fails_oauth_without_providers() {
        let ath = memory_ath().await;
        let outcome = AllRoutesBuilder::new()
            .base_url("http://localhost")
            .oauth()
            .build(&ath);
        assert!(matches!(outcome, Err(AllRoutesError::MissingConfig(_))));
    }

    #[tokio::test]
    async fn build_fails_mfa_without_issuer() {
        let ath = memory_ath().await;
        let outcome = AllRoutesBuilder::new()
            .base_url("http://localhost")
            .mfa()
            .build(&ath);
        assert!(matches!(outcome, Err(AllRoutesError::MissingConfig(_))));
    }

    #[tokio::test]
    async fn build_fails_events_without_base_url() {
        let ath = memory_ath_with_csrf().await;
        let (tx, _rx) = tokio::sync::mpsc::unbounded_channel();
        let outcome = AllRoutesBuilder::new().events(tx).register().build(&ath);
        match outcome {
            Err(AllRoutesError::MissingConfig(msg)) => assert!(msg.contains("base_url")),
            other => panic!("expected MissingConfig, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn build_fails_missing_base_url() {
        let ath = memory_ath().await;
        let outcome = AllRoutesBuilder::new().well_known().build(&ath);
        assert!(matches!(outcome, Err(AllRoutesError::MissingConfig(_))));
    }

    #[tokio::test]
    async fn build_with_invalid_schema_returns_error() {
        let ath = memory_ath_with_csrf().await;
        // A schema with type "array" is not a valid custom fields schema.
        let array_schema = serde_json::json!({"type": "array"});
        let outcome = AllRoutesBuilder::new()
            .register()
            .custom_fields_schema(array_schema)
            .build(&ath);
        assert!(matches!(outcome, Err(AllRoutesError::InvalidSchema(_))));
    }

    #[test]
    fn default_branding_setter_stores_value() {
        use allowthem_core::applications::BrandingConfig;
        let builder = AllRoutesBuilder::new().default_branding(BrandingConfig::new("Fixture Co"));
        let stored_name = builder
            .default_branding
            .as_ref()
            .map(|b| b.application_name.as_str());
        assert_eq!(stored_name, Some("Fixture Co"));
    }

    #[test]
    fn default_branding_absent_by_default() {
        assert!(AllRoutesBuilder::new().default_branding.is_none());
    }

    #[tokio::test]
    async fn build_succeeds_for_simple_routes() {
        let ath = memory_ath_with_csrf().await;
        // Routes that require no extra config beyond defaults.
        let outcome = AllRoutesBuilder::new()
            .login()
            .register()
            .logout()
            .settings()
            .consent()
            .token()
            .userinfo()
            .build(&ath);
        assert!(outcome.is_ok());
    }
}