Skip to main content

allowthem_server/
all_routes.rs

1use std::collections::{HashMap, HashSet};
2use std::fmt;
3use std::sync::Arc;
4
5use axum::Router;
6use minijinja::Environment;
7
8use allowthem_core::{AllowThem, AuthEventSender, EmailSender, OAuthProvider};
9
10use crate::browser_templates::build_default_browser_env;
11
/// A named group of routes that [`AllRoutesBuilder`] can mount on request.
///
/// Every variant has a matching selector method on the builder (for example
/// [`AllRoutesBuilder::login`]); [`AllRoutesBuilder::all_routes`] selects all
/// of them at once.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum RouteGroup {
    /// Login pages (`crate::login_routes`), mounted behind CSRF middleware.
    Login,
    /// Registration pages (`crate::register_routes`), mounted behind CSRF middleware.
    Register,
    /// Logout routes (`crate::logout_routes`), mounted behind CSRF middleware.
    Logout,
    /// Account settings pages (`crate::settings_routes`), mounted behind CSRF middleware.
    Settings,
    /// Consent pages (`crate::consent_routes`), mounted behind CSRF middleware.
    Consent,
    /// Password reset pages (CSRF-protected) plus the password reset routes
    /// (no CSRF); needs `base_url` and an email sender.
    PasswordReset,
    /// MFA setup pages (CSRF-protected) plus MFA challenge pages and MFA routes
    /// (no CSRF); needs `base_url` and an MFA issuer.
    Mfa,
    /// OAuth routes (`crate::oauth_routes`); needs `base_url` and provider implementations.
    OAuth,
    /// Token endpoint (`crate::token_route`), mounted without CSRF middleware.
    Token,
    /// Userinfo endpoint (`crate::userinfo_route`), mounted without CSRF middleware.
    UserInfo,
    /// Well-known discovery routes (`crate::well_known_routes`); needs `base_url`.
    WellKnown,
}
27
/// Failure modes of [`AllRoutesBuilder::build`].
#[derive(Debug)]
pub enum AllRoutesError {
    /// Neither an individual route group nor `all_routes()` was selected.
    NoRoutesSelected,
    /// A selected route group (or the events channel) needs a setting that
    /// was not supplied; the payload says which one.
    MissingConfig(String),
    /// The custom fields schema was rejected; the payload carries the reason.
    InvalidSchema(String),
}

impl fmt::Display for AllRoutesError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Variants with a payload share the "<prefix><detail>" shape.
        let (prefix, detail) = match self {
            AllRoutesError::NoRoutesSelected => {
                return f.write_str("no route groups selected");
            }
            AllRoutesError::MissingConfig(detail) => ("missing config: ", detail),
            AllRoutesError::InvalidSchema(detail) => ("invalid custom fields schema: ", detail),
        };
        f.write_str(prefix)?;
        f.write_str(detail)
    }
}

impl std::error::Error for AllRoutesError {}
47
48/// Builder that assembles allowthem route groups into a single
49/// [`Router<AllowThem>`] with CSRF middleware applied to the correct subset.
50pub struct AllRoutesBuilder {
51    // Shared config
52    templates: Option<Arc<Environment<'static>>>,
53    is_production: bool,
54    base_url: Option<String>,
55    email_sender: Option<Arc<dyn EmailSender>>,
56
57    // Login-specific
58    max_login_attempts: u32,
59    rate_limit_window_secs: u64,
60    oauth_providers_list: Option<Vec<String>>,
61
62    // OAuth-specific
63    oauth_provider_impls: Option<HashMap<String, Box<dyn OAuthProvider>>>,
64
65    // MFA-specific
66    mfa_issuer: Option<String>,
67
68    // Register-specific
69    custom_fields_schema: Option<serde_json::Value>,
70
71    // Event publishing (optional)
72    events_tx: Option<AuthEventSender>,
73
74    // Route selection
75    routes: HashSet<RouteGroup>,
76    all: bool,
77}
78
79impl Default for AllRoutesBuilder {
80    fn default() -> Self {
81        Self::new()
82    }
83}
84
85impl AllRoutesBuilder {
86    pub fn new() -> Self {
87        Self {
88            templates: None,
89            is_production: false,
90            base_url: None,
91            email_sender: None,
92            max_login_attempts: 10,
93            rate_limit_window_secs: 900,
94            oauth_providers_list: None,
95            oauth_provider_impls: None,
96            mfa_issuer: None,
97            custom_fields_schema: None,
98            events_tx: None,
99            routes: HashSet::new(),
100            all: false,
101        }
102    }
103
104    // --- Shared config ---
105
106    pub fn templates(mut self, templates: Arc<Environment<'static>>) -> Self {
107        self.templates = Some(templates);
108        self
109    }
110
111    pub fn is_production(mut self, is_production: bool) -> Self {
112        self.is_production = is_production;
113        self
114    }
115
116    pub fn base_url(mut self, base_url: impl Into<String>) -> Self {
117        self.base_url = Some(base_url.into());
118        self
119    }
120
121    pub fn email_sender(mut self, sender: Arc<dyn EmailSender>) -> Self {
122        self.email_sender = Some(sender);
123        self
124    }
125
126    /// Attach a channel sender that receives lifecycle events (register, etc.).
127    ///
128    /// See `docs/superpowers/specs/2026-04-20-lifecycle-events-design.md` for
129    /// the delivery contract. Called at most once; subsequent calls overwrite.
130    pub fn events(mut self, tx: AuthEventSender) -> Self {
131        self.events_tx = Some(tx);
132        self
133    }
134
135    // --- Login config ---
136
137    pub fn max_login_attempts(mut self, max: u32) -> Self {
138        self.max_login_attempts = max;
139        self
140    }
141
142    pub fn rate_limit_window_secs(mut self, secs: u64) -> Self {
143        self.rate_limit_window_secs = secs;
144        self
145    }
146
147    pub fn oauth_providers_list(mut self, providers: Vec<String>) -> Self {
148        self.oauth_providers_list = Some(providers);
149        self
150    }
151
152    // --- OAuth config ---
153
154    pub fn oauth_providers(mut self, providers: HashMap<String, Box<dyn OAuthProvider>>) -> Self {
155        self.oauth_provider_impls = Some(providers);
156        self
157    }
158
159    // --- MFA config ---
160
161    pub fn mfa_issuer(mut self, issuer: impl Into<String>) -> Self {
162        self.mfa_issuer = Some(issuer.into());
163        self
164    }
165
166    // --- Register config ---
167
168    pub fn custom_fields_schema(mut self, schema: serde_json::Value) -> Self {
169        self.custom_fields_schema = Some(schema);
170        self
171    }
172
173    // --- Route selectors ---
174
175    pub fn login(mut self) -> Self {
176        self.routes.insert(RouteGroup::Login);
177        self
178    }
179
180    pub fn register(mut self) -> Self {
181        self.routes.insert(RouteGroup::Register);
182        self
183    }
184
185    pub fn logout(mut self) -> Self {
186        self.routes.insert(RouteGroup::Logout);
187        self
188    }
189
190    pub fn settings(mut self) -> Self {
191        self.routes.insert(RouteGroup::Settings);
192        self
193    }
194
195    pub fn consent(mut self) -> Self {
196        self.routes.insert(RouteGroup::Consent);
197        self
198    }
199
200    pub fn password_reset(mut self) -> Self {
201        self.routes.insert(RouteGroup::PasswordReset);
202        self
203    }
204
205    pub fn mfa(mut self) -> Self {
206        self.routes.insert(RouteGroup::Mfa);
207        self
208    }
209
210    pub fn oauth(mut self) -> Self {
211        self.routes.insert(RouteGroup::OAuth);
212        self
213    }
214
215    pub fn token(mut self) -> Self {
216        self.routes.insert(RouteGroup::Token);
217        self
218    }
219
220    pub fn userinfo(mut self) -> Self {
221        self.routes.insert(RouteGroup::UserInfo);
222        self
223    }
224
225    pub fn well_known(mut self) -> Self {
226        self.routes.insert(RouteGroup::WellKnown);
227        self
228    }
229
230    pub fn all_routes(mut self) -> Self {
231        self.all = true;
232        self
233    }
234
235    // --- Build ---
236
237    fn selected(&self, group: RouteGroup) -> bool {
238        self.all || self.routes.contains(&group)
239    }
240
241    pub fn build(mut self, ath: &AllowThem) -> Result<Router<AllowThem>, AllRoutesError> {
242        if !self.all && self.routes.is_empty() {
243            return Err(AllRoutesError::NoRoutesSelected);
244        }
245
246        // --- Validate required config for selected groups ---
247
248        let needs_base_url = self.selected(RouteGroup::PasswordReset)
249            || self.selected(RouteGroup::Mfa)
250            || self.selected(RouteGroup::OAuth)
251            || self.selected(RouteGroup::WellKnown);
252
253        if needs_base_url && self.base_url.is_none() {
254            return Err(AllRoutesError::MissingConfig(
255                "base_url required by selected route groups".into(),
256            ));
257        }
258
259        if self.events_tx.is_some() && self.base_url.is_none() {
260            return Err(AllRoutesError::MissingConfig(
261                "base_url required when events channel is configured".into(),
262            ));
263        }
264
265        if self.selected(RouteGroup::OAuth) && self.oauth_provider_impls.is_none() {
266            return Err(AllRoutesError::MissingConfig(
267                "oauth_providers required when oauth routes are selected".into(),
268            ));
269        }
270
271        if self.selected(RouteGroup::PasswordReset) && self.email_sender.is_none() {
272            return Err(AllRoutesError::MissingConfig(
273                "email_sender required when password_reset routes are selected".into(),
274            ));
275        }
276
277        if self.selected(RouteGroup::Mfa) && self.mfa_issuer.is_none() {
278            return Err(AllRoutesError::MissingConfig(
279                "mfa_issuer required when mfa routes are selected".into(),
280            ));
281        }
282
283        // --- Resolve defaults ---
284
285        let templates = self
286            .templates
287            .take()
288            .unwrap_or_else(build_default_browser_env);
289        let is_production = self.is_production;
290
291        // Derive oauth_providers_list from the provider map keys when not
292        // explicitly set. This avoids requiring the caller to duplicate the
293        // provider names.
294        let oauth_providers_list = self
295            .oauth_providers_list
296            .take()
297            .unwrap_or_else(|| match &self.oauth_provider_impls {
298                Some(map) => {
299                    let mut names: Vec<String> = map.keys().cloned().collect();
300                    names.sort();
301                    names
302                }
303                None => Vec::new(),
304            });
305
306        // --- CSRF-protected routes (browser routes) ---
307
308        let mut csrf_protected: Router<AllowThem> = Router::new();
309
310        if self.selected(RouteGroup::Login) {
311            csrf_protected = csrf_protected.merge(crate::login_routes::login_routes(
312                templates.clone(),
313                is_production,
314                self.max_login_attempts,
315                self.rate_limit_window_secs,
316                oauth_providers_list,
317            ));
318        }
319
320        if self.selected(RouteGroup::Register) {
321            let custom_schema = if let Some(schema) = self.custom_fields_schema.take() {
322                crate::custom_fields::validate_custom_schema(&schema)
323                    .map_err(AllRoutesError::InvalidSchema)?;
324                let validator = jsonschema::validator_for(&schema)
325                    .map_err(|e| AllRoutesError::InvalidSchema(e.to_string()))?;
326                let fields = crate::custom_fields::extract_field_descriptors(&schema);
327                Some(crate::custom_fields::CustomSchemaConfig {
328                    schema,
329                    validator,
330                    fields,
331                })
332            } else {
333                None
334            };
335            csrf_protected = csrf_protected.merge(crate::register_routes::register_routes(
336                templates.clone(),
337                is_production,
338                custom_schema,
339                self.events_tx.clone(),
340                self.base_url.clone(),
341            ));
342        }
343
344        if self.selected(RouteGroup::Logout) {
345            csrf_protected = csrf_protected.merge(crate::logout_routes::logout_routes());
346        }
347
348        if self.selected(RouteGroup::Settings) {
349            csrf_protected = csrf_protected.merge(crate::settings_routes::settings_routes(
350                templates.clone(),
351                is_production,
352            ));
353        }
354
355        if self.selected(RouteGroup::Consent) {
356            csrf_protected = csrf_protected.merge(crate::consent_routes::consent_routes(
357                templates.clone(),
358                is_production,
359            ));
360        }
361
362        if self.selected(RouteGroup::PasswordReset) {
363            let email_sender = self.email_sender.clone().expect("validated above");
364            let base_url = self.base_url.clone().expect("validated above");
365            csrf_protected = csrf_protected.merge(
366                crate::password_reset_page_routes::password_reset_page_routes(
367                    templates.clone(),
368                    is_production,
369                    email_sender,
370                    base_url,
371                ),
372            );
373        }
374
375        if self.selected(RouteGroup::Mfa) {
376            let base_url = self.base_url.clone().expect("validated above");
377            csrf_protected = csrf_protected.merge(crate::mfa_page_routes::mfa_setup_routes(
378                templates.clone(),
379                is_production,
380                base_url,
381            ));
382        }
383
384        // --- Non-CSRF routes ---
385
386        let mut non_csrf: Router<AllowThem> = Router::new();
387
388        if self.selected(RouteGroup::Mfa) {
389            non_csrf = non_csrf.merge(crate::mfa_page_routes::mfa_challenge_routes(
390                templates.clone(),
391                is_production,
392            ));
393            let issuer = self.mfa_issuer.take().expect("validated above");
394            non_csrf = non_csrf.merge(crate::mfa_routes::mfa_routes(issuer));
395        }
396
397        if self.selected(RouteGroup::OAuth) {
398            let providers = self.oauth_provider_impls.take().expect("validated above");
399            let base_url = self.base_url.clone().expect("validated above");
400            non_csrf = non_csrf.merge(crate::oauth_routes::oauth_routes(
401                providers,
402                base_url,
403                self.events_tx.clone(),
404            ));
405        }
406
407        if self.selected(RouteGroup::Token) {
408            non_csrf = non_csrf.merge(crate::token_route::token_route());
409        }
410
411        if self.selected(RouteGroup::UserInfo) {
412            non_csrf = non_csrf.merge(crate::userinfo_route::userinfo_route());
413        }
414
415        if self.selected(RouteGroup::WellKnown) {
416            let base_url = self.base_url.clone().expect("validated above");
417            non_csrf = non_csrf.merge(crate::well_known_routes::well_known_routes(base_url));
418        }
419
420        if self.selected(RouteGroup::PasswordReset) {
421            let email_sender = self.email_sender.take().expect("validated above");
422            let base_url = self.base_url.expect("validated above");
423            non_csrf = non_csrf.merge(crate::password_reset_routes::password_reset_routes(
424                email_sender,
425                base_url,
426            ));
427        }
428
429        // Apply CSRF middleware to the browser routes, then merge in the
430        // non-CSRF routes underneath. `from_fn_with_state` is required
431        // because `csrf_middleware` extracts `State<AllowThem>`.
432        Ok(csrf_protected
433            .layer(axum::middleware::from_fn_with_state(
434                ath.clone(),
435                crate::csrf::csrf_middleware,
436            ))
437            .merge(non_csrf))
438    }
439}
440
#[cfg(test)]
mod tests {
    use super::*;
    use allowthem_core::AllowThemBuilder;

    /// Builds an in-memory `AllowThem` with a fixed CSRF key, shared by every test.
    async fn test_ath() -> AllowThem {
        AllowThemBuilder::new("sqlite::memory:")
            .csrf_key(*b"test-csrf-key-for-server-tests!!")
            .build()
            .await
            .unwrap()
    }

    /// Asserts that `result` is `MissingConfig` and that its message mentions `needle`.
    fn assert_missing_config(result: Result<Router<AllowThem>, AllRoutesError>, needle: &str) {
        match result {
            Err(AllRoutesError::MissingConfig(msg)) => assert!(
                msg.contains(needle),
                "expected message mentioning {needle:?}, got {msg:?}"
            ),
            other => panic!("expected MissingConfig, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn build_fails_no_routes_selected() {
        let ath = test_ath().await;
        let result = AllRoutesBuilder::new().build(&ath);
        assert!(matches!(result, Err(AllRoutesError::NoRoutesSelected)));
    }

    #[tokio::test]
    async fn build_fails_oauth_without_providers() {
        let ath = test_ath().await;
        let result = AllRoutesBuilder::new()
            .base_url("http://localhost")
            .oauth()
            .build(&ath);
        assert_missing_config(result, "oauth_providers");
    }

    #[tokio::test]
    async fn build_fails_password_reset_without_email_sender() {
        let ath = test_ath().await;
        let result = AllRoutesBuilder::new()
            .base_url("http://localhost")
            .password_reset()
            .build(&ath);
        assert_missing_config(result, "email_sender");
    }

    #[tokio::test]
    async fn build_fails_mfa_without_issuer() {
        let ath = test_ath().await;
        let result = AllRoutesBuilder::new()
            .base_url("http://localhost")
            .mfa()
            .build(&ath);
        assert_missing_config(result, "mfa_issuer");
    }

    #[tokio::test]
    async fn build_fails_events_without_base_url() {
        let ath = test_ath().await;
        let (tx, _rx) = tokio::sync::mpsc::unbounded_channel();
        let result = AllRoutesBuilder::new().events(tx).register().build(&ath);
        assert_missing_config(result, "base_url");
    }

    #[tokio::test]
    async fn build_fails_missing_base_url() {
        let ath = test_ath().await;
        let result = AllRoutesBuilder::new().well_known().build(&ath);
        assert_missing_config(result, "base_url");
    }

    #[tokio::test]
    async fn build_fails_all_routes_without_config() {
        let ath = test_ath().await;
        // `all_routes` pulls in every group, so base_url is the first missing setting.
        let result = AllRoutesBuilder::new().all_routes().build(&ath);
        assert_missing_config(result, "base_url");
    }

    #[tokio::test]
    async fn build_with_invalid_schema_returns_error() {
        let ath = test_ath().await;
        // A schema with type "array" is not a valid custom fields schema
        let schema = serde_json::json!({"type": "array"});
        let result = AllRoutesBuilder::new()
            .register()
            .custom_fields_schema(schema)
            .build(&ath);
        assert!(matches!(result, Err(AllRoutesError::InvalidSchema(_))));
    }

    #[tokio::test]
    async fn build_succeeds_for_simple_routes() {
        let ath = test_ath().await;
        // Routes that require no extra config beyond defaults
        let result = AllRoutesBuilder::new()
            .login()
            .register()
            .logout()
            .settings()
            .consent()
            .token()
            .userinfo()
            .build(&ath);
        assert!(result.is_ok());
    }

    #[test]
    fn error_display_messages() {
        assert_eq!(
            AllRoutesError::NoRoutesSelected.to_string(),
            "no route groups selected"
        );
        assert_eq!(
            AllRoutesError::MissingConfig("base_url".into()).to_string(),
            "missing config: base_url"
        );
        assert_eq!(
            AllRoutesError::InvalidSchema("bad".into()).to_string(),
            "invalid custom fields schema: bad"
        );
    }
}