Skip to main content

allowthem_server/
all_routes.rs

1use std::collections::{HashMap, HashSet};
2use std::fmt;
3use std::sync::Arc;
4
5use axum::Router;
6use minijinja::Environment;
7
8use allowthem_core::{AllowThem, AuthEventSender, EmailSender, OAuthProvider};
9
10use crate::browser_templates::build_default_browser_env;
11
/// A named slice of allowthem's HTTP surface that [`AllRoutesBuilder`] can
/// mount on its own; several groups may be selected at once.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum RouteGroup {
    /// Login pages.
    Login,
    /// Registration pages.
    Register,
    /// Logout route.
    Logout,
    /// Account settings pages.
    Settings,
    /// Consent pages.
    Consent,
    /// Password-reset pages plus their non-CSRF companion routes.
    PasswordReset,
    /// MFA setup pages, the MFA challenge page, and the MFA API routes.
    Mfa,
    /// OAuth provider routes.
    OAuth,
    /// The `/oauth/token` endpoint.
    Token,
    /// The `/oauth/userinfo` endpoint.
    UserInfo,
    /// Routes under `/.well-known/`.
    WellKnown,
}
27
/// Errors returned by [`AllRoutesBuilder::build`].
#[derive(Debug)]
pub enum AllRoutesError {
    /// Neither an individual route selector nor `all_routes` was used.
    NoRoutesSelected,
    /// A selected route group (or an attached events channel) needs a
    /// setting that was never provided; the string names it.
    MissingConfig(String),
    /// The custom fields schema was rejected; the string carries the reason.
    InvalidSchema(String),
}

impl fmt::Display for AllRoutesError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            AllRoutesError::NoRoutesSelected => write!(f, "no route groups selected"),
            AllRoutesError::MissingConfig(msg) => write!(f, "missing config: {msg}"),
            AllRoutesError::InvalidSchema(msg) => {
                write!(f, "invalid custom fields schema: {msg}")
            }
        }
    }
}

/// Marker impl so the error composes with `?` and `Box<dyn Error>`.
impl std::error::Error for AllRoutesError {}
47
48/// Builder that assembles allowthem route groups into a single
49/// [`Router<()>`] with CSRF middleware applied to the correct subset.
50pub struct AllRoutesBuilder {
51    // Shared config
52    templates: Option<Arc<Environment<'static>>>,
53    is_production: bool,
54    base_url: Option<String>,
55    email_sender: Option<Arc<dyn EmailSender>>,
56
57    // Login-specific
58    max_login_attempts: u32,
59    rate_limit_window_secs: u64,
60    oauth_providers_list: Option<Vec<String>>,
61
62    // OAuth-specific
63    oauth_provider_impls: Option<HashMap<String, Box<dyn OAuthProvider>>>,
64
65    // MFA-specific
66    mfa_issuer: Option<String>,
67
68    // Register-specific
69    custom_fields_schema: Option<serde_json::Value>,
70
71    // Event publishing (optional)
72    events_tx: Option<AuthEventSender>,
73
74    // Route selection
75    routes: HashSet<RouteGroup>,
76    all: bool,
77
78    // CORS for OIDC endpoints
79    cors_enabled: bool,
80
81    // Embedder fallback branding
82    default_branding: Option<allowthem_core::applications::BrandingConfig>,
83}
84
85impl Default for AllRoutesBuilder {
86    fn default() -> Self {
87        Self::new()
88    }
89}
90
91impl AllRoutesBuilder {
92    pub fn new() -> Self {
93        Self {
94            templates: None,
95            is_production: false,
96            base_url: None,
97            email_sender: None,
98            max_login_attempts: 10,
99            rate_limit_window_secs: 900,
100            oauth_providers_list: None,
101            oauth_provider_impls: None,
102            mfa_issuer: None,
103            custom_fields_schema: None,
104            events_tx: None,
105            routes: HashSet::new(),
106            all: false,
107            cors_enabled: false,
108            default_branding: None,
109        }
110    }
111
112    // --- Shared config ---
113
114    pub fn templates(mut self, templates: Arc<Environment<'static>>) -> Self {
115        self.templates = Some(templates);
116        self
117    }
118
119    pub fn is_production(mut self, is_production: bool) -> Self {
120        self.is_production = is_production;
121        self
122    }
123
124    pub fn base_url(mut self, base_url: impl Into<String>) -> Self {
125        self.base_url = Some(base_url.into());
126        self
127    }
128
129    pub fn email_sender(mut self, sender: Arc<dyn EmailSender>) -> Self {
130        self.email_sender = Some(sender);
131        self
132    }
133
134    /// Attach a channel sender that receives lifecycle events (register, etc.).
135    ///
136    /// See `docs/superpowers/specs/2026-04-20-lifecycle-events-design.md` for
137    /// the delivery contract. Called at most once; subsequent calls overwrite.
138    pub fn events(mut self, tx: AuthEventSender) -> Self {
139        self.events_tx = Some(tx);
140        self
141    }
142
143    /// Set a fallback branding used for every pre-auth page when no per-client
144    /// lookup matches (no `client_id` query param, or no row with that id).
145    /// Embedders typically call this once at startup with their own app name,
146    /// accent, and splash metadata.
147    pub fn default_branding(
148        mut self,
149        branding: allowthem_core::applications::BrandingConfig,
150    ) -> Self {
151        self.default_branding = Some(branding);
152        self
153    }
154
155    // --- Login config ---
156
157    pub fn max_login_attempts(mut self, max: u32) -> Self {
158        self.max_login_attempts = max;
159        self
160    }
161
162    pub fn rate_limit_window_secs(mut self, secs: u64) -> Self {
163        self.rate_limit_window_secs = secs;
164        self
165    }
166
167    pub fn oauth_providers_list(mut self, providers: Vec<String>) -> Self {
168        self.oauth_providers_list = Some(providers);
169        self
170    }
171
172    // --- OAuth config ---
173
174    pub fn oauth_providers(mut self, providers: HashMap<String, Box<dyn OAuthProvider>>) -> Self {
175        self.oauth_provider_impls = Some(providers);
176        self
177    }
178
179    // --- MFA config ---
180
181    pub fn mfa_issuer(mut self, issuer: impl Into<String>) -> Self {
182        self.mfa_issuer = Some(issuer.into());
183        self
184    }
185
186    // --- Register config ---
187
188    pub fn custom_fields_schema(mut self, schema: serde_json::Value) -> Self {
189        self.custom_fields_schema = Some(schema);
190        self
191    }
192
193    // --- Route selectors ---
194
195    pub fn login(mut self) -> Self {
196        self.routes.insert(RouteGroup::Login);
197        self
198    }
199
200    pub fn register(mut self) -> Self {
201        self.routes.insert(RouteGroup::Register);
202        self
203    }
204
205    pub fn logout(mut self) -> Self {
206        self.routes.insert(RouteGroup::Logout);
207        self
208    }
209
210    pub fn settings(mut self) -> Self {
211        self.routes.insert(RouteGroup::Settings);
212        self
213    }
214
215    pub fn consent(mut self) -> Self {
216        self.routes.insert(RouteGroup::Consent);
217        self
218    }
219
220    pub fn password_reset(mut self) -> Self {
221        self.routes.insert(RouteGroup::PasswordReset);
222        self
223    }
224
225    pub fn mfa(mut self) -> Self {
226        self.routes.insert(RouteGroup::Mfa);
227        self
228    }
229
230    pub fn oauth(mut self) -> Self {
231        self.routes.insert(RouteGroup::OAuth);
232        self
233    }
234
235    pub fn token(mut self) -> Self {
236        self.routes.insert(RouteGroup::Token);
237        self
238    }
239
240    pub fn userinfo(mut self) -> Self {
241        self.routes.insert(RouteGroup::UserInfo);
242        self
243    }
244
245    pub fn well_known(mut self) -> Self {
246        self.routes.insert(RouteGroup::WellKnown);
247        self
248    }
249
250    pub fn all_routes(mut self) -> Self {
251        self.all = true;
252        self
253    }
254
255    /// Enable dynamic CORS for OIDC endpoints (`/oauth/token`, `/oauth/userinfo`,
256    /// `/.well-known/*`). Allowed origins are derived per-request from active
257    /// application redirect URIs. Has no effect unless Token, UserInfo, or
258    /// WellKnown routes are also selected.
259    pub fn cors(mut self) -> Self {
260        self.cors_enabled = true;
261        self
262    }
263
264    // --- Build ---
265
266    fn selected(&self, group: RouteGroup) -> bool {
267        self.all || self.routes.contains(&group)
268    }
269
270    fn validate(&self) -> Result<(), AllRoutesError> {
271        if !self.all && self.routes.is_empty() {
272            return Err(AllRoutesError::NoRoutesSelected);
273        }
274
275        let needs_base_url = self.selected(RouteGroup::PasswordReset)
276            || self.selected(RouteGroup::Mfa)
277            || self.selected(RouteGroup::OAuth)
278            || self.selected(RouteGroup::WellKnown);
279
280        if needs_base_url && self.base_url.is_none() {
281            return Err(AllRoutesError::MissingConfig(
282                "base_url required by selected route groups".into(),
283            ));
284        }
285
286        if self.events_tx.is_some() && self.base_url.is_none() {
287            return Err(AllRoutesError::MissingConfig(
288                "base_url required when events channel is configured".into(),
289            ));
290        }
291
292        if self.selected(RouteGroup::OAuth) && self.oauth_provider_impls.is_none() {
293            return Err(AllRoutesError::MissingConfig(
294                "oauth_providers required when oauth routes are selected".into(),
295            ));
296        }
297
298        if self.selected(RouteGroup::PasswordReset) && self.email_sender.is_none() {
299            return Err(AllRoutesError::MissingConfig(
300                "email_sender required when password_reset routes are selected".into(),
301            ));
302        }
303
304        if self.selected(RouteGroup::Mfa) && self.mfa_issuer.is_none() {
305            return Err(AllRoutesError::MissingConfig(
306                "mfa_issuer required when mfa routes are selected".into(),
307            ));
308        }
309
310        Ok(())
311    }
312
313    fn build_inner(mut self) -> Result<Router<()>, AllRoutesError> {
314        self.validate()?;
315
316        // --- Resolve defaults ---
317
318        let templates = self
319            .templates
320            .take()
321            .unwrap_or_else(build_default_browser_env);
322        let is_production = self.is_production;
323
324        // Derive oauth_providers_list from the provider map keys when not
325        // explicitly set. This avoids requiring the caller to duplicate the
326        // provider names.
327        let oauth_providers_list = self
328            .oauth_providers_list
329            .take()
330            .unwrap_or_else(|| match &self.oauth_provider_impls {
331                Some(map) => {
332                    let mut names: Vec<String> = map.keys().cloned().collect();
333                    names.sort();
334                    names
335                }
336                None => Vec::new(),
337            });
338
339        // --- CSRF-protected routes (browser routes) ---
340
341        let mut csrf_protected: Router<()> = Router::new();
342
343        if self.selected(RouteGroup::Login) {
344            csrf_protected = csrf_protected.merge(crate::login_routes::login_routes(
345                templates.clone(),
346                is_production,
347                self.max_login_attempts,
348                self.rate_limit_window_secs,
349                oauth_providers_list.clone(),
350            ));
351        }
352
353        if self.selected(RouteGroup::Register) {
354            let custom_schema = if let Some(schema) = self.custom_fields_schema.take() {
355                crate::custom_fields::validate_custom_schema(&schema)
356                    .map_err(AllRoutesError::InvalidSchema)?;
357                let validator = jsonschema::validator_for(&schema)
358                    .map_err(|e| AllRoutesError::InvalidSchema(e.to_string()))?;
359                let fields = crate::custom_fields::extract_field_descriptors(&schema);
360                Some(crate::custom_fields::CustomSchemaConfig {
361                    schema,
362                    validator,
363                    fields,
364                })
365            } else {
366                None
367            };
368            csrf_protected = csrf_protected.merge(crate::register_routes::register_routes(
369                templates.clone(),
370                is_production,
371                custom_schema,
372                self.events_tx.clone(),
373                self.base_url.clone(),
374                oauth_providers_list.clone(),
375            ));
376        }
377
378        if self.selected(RouteGroup::Logout) {
379            csrf_protected = csrf_protected.merge(crate::logout_routes::logout_routes());
380        }
381
382        if self.selected(RouteGroup::Settings) {
383            csrf_protected = csrf_protected.merge(crate::settings_routes::settings_routes(
384                templates.clone(),
385                is_production,
386            ));
387        }
388
389        if self.selected(RouteGroup::Consent) {
390            csrf_protected = csrf_protected.merge(crate::consent_routes::consent_routes(
391                templates.clone(),
392                is_production,
393            ));
394        }
395
396        if self.selected(RouteGroup::PasswordReset) {
397            let email_sender = self.email_sender.clone().expect("validated above");
398            let base_url = self.base_url.clone().expect("validated above");
399            csrf_protected = csrf_protected.merge(
400                crate::password_reset_page_routes::password_reset_page_routes(
401                    templates.clone(),
402                    is_production,
403                    email_sender,
404                    base_url,
405                ),
406            );
407        }
408
409        if self.selected(RouteGroup::Mfa) {
410            let base_url = self.base_url.clone().expect("validated above");
411            csrf_protected = csrf_protected.merge(crate::mfa_page_routes::mfa_setup_routes(
412                templates.clone(),
413                is_production,
414                base_url,
415            ));
416        }
417
418        // --- Non-CSRF routes ---
419
420        let mut non_csrf: Router<()> = Router::new();
421
422        if self.selected(RouteGroup::Mfa) {
423            non_csrf = non_csrf.merge(crate::mfa_page_routes::mfa_challenge_routes(
424                templates.clone(),
425                is_production,
426            ));
427            let issuer = self.mfa_issuer.take().expect("validated above");
428            non_csrf = non_csrf.merge(crate::mfa_routes::mfa_routes(issuer));
429        }
430
431        if self.selected(RouteGroup::OAuth) {
432            let providers = self.oauth_provider_impls.take().expect("validated above");
433            let base_url = self.base_url.clone().expect("validated above");
434            non_csrf = non_csrf.merge(crate::oauth_routes::oauth_routes(
435                providers,
436                base_url,
437                self.events_tx.clone(),
438            ));
439        }
440
441        // --- OIDC sub-router (CORS-eligible routes) ---
442
443        let mut oidc: Router<()> = Router::new();
444
445        if self.selected(RouteGroup::Token) {
446            oidc = oidc.merge(crate::token_route::token_route());
447        }
448
449        if self.selected(RouteGroup::UserInfo) {
450            oidc = oidc.merge(crate::userinfo_route::userinfo_route());
451        }
452
453        if self.selected(RouteGroup::WellKnown) {
454            let base_url = self.base_url.clone().expect("validated above");
455            oidc = oidc.merge(crate::well_known_routes::well_known_routes(base_url));
456        }
457
458        // cors_middleware reads AllowThem from extensions; the inject shim is
459        // applied by the caller (build/build_for_saas) at the appropriate scope.
460        let oidc_final: Router<()> = if self.cors_enabled {
461            oidc.layer(axum::middleware::from_fn(crate::cors::cors_middleware))
462        } else {
463            oidc
464        };
465
466        non_csrf = non_csrf.merge(oidc_final);
467
468        if self.selected(RouteGroup::PasswordReset) {
469            let email_sender = self.email_sender.take().expect("validated above");
470            let base_url = self.base_url.expect("validated above");
471            non_csrf = non_csrf.merge(crate::password_reset_routes::password_reset_routes(
472                email_sender,
473                base_url,
474            ));
475        }
476
477        let default_branding = self.default_branding.take();
478
479        // Apply CSRF middleware to browser routes, then merge non-CSRF routes.
480        // Both csrf_middleware and all handlers read AllowThem from extensions.
481        // Static assets are merged unconditionally and bypass all middleware.
482        let mut csrf_protected =
483            csrf_protected.layer(axum::middleware::from_fn(crate::csrf::csrf_middleware));
484
485        if let Some(branding) = default_branding {
486            let default = Arc::new(crate::branding::DefaultBranding(branding));
487            csrf_protected = csrf_protected.layer(axum::Extension(default));
488        }
489
490        Ok(csrf_protected
491            .merge(non_csrf)
492            .merge(crate::static_routes::router()))
493    }
494
495    /// Build routes for standalone mode. Wraps `build_inner` with the inject
496    /// shim that bridges `State<AllowThem>` into request extensions.
497    pub fn build(self, ath: &AllowThem) -> Result<Router<()>, AllRoutesError> {
498        let inner = self.build_inner()?;
499        Ok(inner.layer(axum::middleware::from_fn_with_state(
500            ath.clone(),
501            crate::cors::inject_ath_into_extensions,
502        )))
503    }
504
505    /// Build routes for SaaS mode. The tenant router injects AllowThem into
506    /// extensions before dispatching, so no inject shim is added here.
507    pub fn build_for_saas(self) -> Result<Router<()>, AllRoutesError> {
508        self.build_inner()
509    }
510}
511
#[cfg(test)]
mod tests {
    use super::*;
    use allowthem_core::AllowThemBuilder;

    /// Fixed 32-byte CSRF key for tests whose routes need one.
    const TEST_CSRF_KEY: &[u8; 32] = b"test-csrf-key-for-server-tests!!";

    /// In-memory AllowThem instance without a CSRF key.
    async fn plain_ath() -> AllowThem {
        AllowThemBuilder::new("sqlite::memory:")
            .build()
            .await
            .unwrap()
    }

    /// In-memory AllowThem instance configured with [`TEST_CSRF_KEY`].
    async fn csrf_ath() -> AllowThem {
        AllowThemBuilder::new("sqlite::memory:")
            .csrf_key(*TEST_CSRF_KEY)
            .build()
            .await
            .unwrap()
    }

    #[tokio::test]
    async fn build_fails_no_routes_selected() {
        let handle = plain_ath().await;
        let outcome = AllRoutesBuilder::new().build(&handle);
        assert!(matches!(outcome, Err(AllRoutesError::NoRoutesSelected)));
    }

    #[tokio::test]
    async fn build_fails_oauth_without_providers() {
        let handle = plain_ath().await;
        let outcome = AllRoutesBuilder::new()
            .base_url("http://localhost")
            .oauth()
            .build(&handle);
        assert!(matches!(outcome, Err(AllRoutesError::MissingConfig(_))));
    }

    #[tokio::test]
    async fn build_fails_password_reset_without_email_sender() {
        let handle = plain_ath().await;
        let outcome = AllRoutesBuilder::new()
            .base_url("http://localhost")
            .password_reset()
            .build(&handle);
        assert!(matches!(outcome, Err(AllRoutesError::MissingConfig(_))));
    }

    #[tokio::test]
    async fn build_fails_mfa_without_issuer() {
        let handle = plain_ath().await;
        let outcome = AllRoutesBuilder::new()
            .base_url("http://localhost")
            .mfa()
            .build(&handle);
        assert!(matches!(outcome, Err(AllRoutesError::MissingConfig(_))));
    }

    #[tokio::test]
    async fn build_fails_events_without_base_url() {
        let handle = csrf_ath().await;
        let (sender, _receiver) = tokio::sync::mpsc::unbounded_channel();
        let outcome = AllRoutesBuilder::new()
            .events(sender)
            .register()
            .build(&handle);
        match outcome {
            Err(AllRoutesError::MissingConfig(msg)) => assert!(msg.contains("base_url")),
            other => panic!("expected MissingConfig, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn build_fails_missing_base_url() {
        let handle = plain_ath().await;
        let outcome = AllRoutesBuilder::new().well_known().build(&handle);
        assert!(matches!(outcome, Err(AllRoutesError::MissingConfig(_))));
    }

    #[tokio::test]
    async fn build_with_invalid_schema_returns_error() {
        let handle = csrf_ath().await;
        // A schema with type "array" is not a valid custom fields schema.
        let array_schema = serde_json::json!({"type": "array"});
        let outcome = AllRoutesBuilder::new()
            .register()
            .custom_fields_schema(array_schema)
            .build(&handle);
        assert!(matches!(outcome, Err(AllRoutesError::InvalidSchema(_))));
    }

    #[test]
    fn default_branding_setter_stores_value() {
        use allowthem_core::applications::BrandingConfig;
        let builder = AllRoutesBuilder::new().default_branding(BrandingConfig::new("Fixture Co"));
        let stored_name = builder
            .default_branding
            .as_ref()
            .map(|b| b.application_name.as_str());
        assert_eq!(stored_name, Some("Fixture Co"));
    }

    #[test]
    fn default_branding_absent_by_default() {
        assert!(AllRoutesBuilder::new().default_branding.is_none());
    }

    #[tokio::test]
    async fn build_succeeds_for_simple_routes() {
        let handle = csrf_ath().await;
        // Routes that require no extra config beyond defaults.
        let outcome = AllRoutesBuilder::new()
            .login()
            .register()
            .logout()
            .settings()
            .consent()
            .token()
            .userinfo()
            .build(&handle);
        assert!(outcome.is_ok());
    }
}