1use std::collections::{HashMap, HashSet};
2use std::fmt;
3use std::sync::Arc;
4
5use axum::Router;
6use minijinja::Environment;
7
8use allowthem_core::{AllowThem, AuthEventSender, EmailSender, OAuthProvider};
9
10use crate::browser_templates::build_default_browser_env;
11
/// A group of routes that `AllRoutesBuilder` can mount.
///
/// Each variant maps to one (or, for `PasswordReset` and `Mfa`, several)
/// route constructors invoked by `AllRoutesBuilder::build`. Browser-facing
/// groups are placed behind the CSRF middleware; the others are merged
/// without it.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum RouteGroup {
    /// Login pages (`login_routes`); CSRF-protected.
    Login,
    /// Registration pages (`register_routes`); CSRF-protected.
    Register,
    /// Logout (`logout_routes`); CSRF-protected.
    Logout,
    /// Account settings pages (`settings_routes`); CSRF-protected.
    Settings,
    /// Consent pages (`consent_routes`); CSRF-protected.
    Consent,
    /// Password-reset pages (CSRF-protected) plus password-reset routes
    /// mounted without CSRF. Requires `base_url` and `email_sender`.
    PasswordReset,
    /// MFA setup pages (CSRF-protected) plus MFA challenge pages and MFA
    /// routes mounted without CSRF. Requires `base_url` and `mfa_issuer`.
    Mfa,
    /// OAuth routes (no CSRF layer). Requires `base_url` and provider impls.
    OAuth,
    /// Token route (no CSRF layer).
    Token,
    /// Userinfo route (no CSRF layer).
    UserInfo,
    /// Well-known routes (no CSRF layer). Requires `base_url`.
    WellKnown,
}
27
/// Reasons `AllRoutesBuilder::build` can refuse to produce a router.
#[derive(Debug)]
pub enum AllRoutesError {
    /// Neither an individual route group nor `all_routes()` was selected.
    NoRoutesSelected,
    /// A selected route group (or the events channel) needs a setting that
    /// was not provided; the payload names the missing requirement.
    MissingConfig(String),
    /// The custom fields schema was rejected; the payload carries the reason.
    InvalidSchema(String),
}

impl fmt::Display for AllRoutesError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            AllRoutesError::MissingConfig(msg) => write!(f, "missing config: {msg}"),
            AllRoutesError::InvalidSchema(msg) => {
                write!(f, "invalid custom fields schema: {msg}")
            }
            AllRoutesError::NoRoutesSelected => f.write_str("no route groups selected"),
        }
    }
}

/// No underlying cause is attached, so `source()` keeps its default `None`.
impl std::error::Error for AllRoutesError {}
47
/// Builder that selects route groups, collects the configuration they need,
/// and assembles them into a single `Router<AllowThem>` via `build`.
pub struct AllRoutesBuilder {
    /// Template environment; `build` falls back to `build_default_browser_env`
    /// when this is unset.
    templates: Option<Arc<Environment<'static>>>,
    /// Forwarded to the browser-facing route constructors.
    is_production: bool,
    /// Required by password-reset, MFA, OAuth and well-known routes, and
    /// whenever `events_tx` is set; passed optionally to register routes.
    base_url: Option<String>,
    /// Required when password-reset routes are selected.
    email_sender: Option<Arc<dyn EmailSender>>,

    /// Forwarded to `login_routes` (default 10 in `new`).
    max_login_attempts: u32,
    /// Forwarded to `login_routes` (default 900 in `new`).
    rate_limit_window_secs: u64,
    /// Provider names handed to the login routes; when unset, `build` derives
    /// them from the sorted keys of `oauth_provider_impls`.
    oauth_providers_list: Option<Vec<String>>,

    /// Provider implementations keyed by name; required by OAuth routes.
    oauth_provider_impls: Option<HashMap<String, Box<dyn OAuthProvider>>>,

    /// Issuer passed to the MFA routes; required when MFA routes are selected.
    mfa_issuer: Option<String>,

    /// Optional JSON schema for extra registration fields; validated and used
    /// only when register routes are selected.
    custom_fields_schema: Option<serde_json::Value>,

    /// Optional auth event channel, forwarded to register and OAuth routes.
    events_tx: Option<AuthEventSender>,

    /// Individually selected route groups.
    routes: HashSet<RouteGroup>,
    /// When true, every route group counts as selected regardless of `routes`.
    all: bool,
}
78
79impl Default for AllRoutesBuilder {
80 fn default() -> Self {
81 Self::new()
82 }
83}
84
85impl AllRoutesBuilder {
86 pub fn new() -> Self {
87 Self {
88 templates: None,
89 is_production: false,
90 base_url: None,
91 email_sender: None,
92 max_login_attempts: 10,
93 rate_limit_window_secs: 900,
94 oauth_providers_list: None,
95 oauth_provider_impls: None,
96 mfa_issuer: None,
97 custom_fields_schema: None,
98 events_tx: None,
99 routes: HashSet::new(),
100 all: false,
101 }
102 }
103
104 pub fn templates(mut self, templates: Arc<Environment<'static>>) -> Self {
107 self.templates = Some(templates);
108 self
109 }
110
111 pub fn is_production(mut self, is_production: bool) -> Self {
112 self.is_production = is_production;
113 self
114 }
115
116 pub fn base_url(mut self, base_url: impl Into<String>) -> Self {
117 self.base_url = Some(base_url.into());
118 self
119 }
120
121 pub fn email_sender(mut self, sender: Arc<dyn EmailSender>) -> Self {
122 self.email_sender = Some(sender);
123 self
124 }
125
126 pub fn events(mut self, tx: AuthEventSender) -> Self {
131 self.events_tx = Some(tx);
132 self
133 }
134
135 pub fn max_login_attempts(mut self, max: u32) -> Self {
138 self.max_login_attempts = max;
139 self
140 }
141
142 pub fn rate_limit_window_secs(mut self, secs: u64) -> Self {
143 self.rate_limit_window_secs = secs;
144 self
145 }
146
147 pub fn oauth_providers_list(mut self, providers: Vec<String>) -> Self {
148 self.oauth_providers_list = Some(providers);
149 self
150 }
151
152 pub fn oauth_providers(mut self, providers: HashMap<String, Box<dyn OAuthProvider>>) -> Self {
155 self.oauth_provider_impls = Some(providers);
156 self
157 }
158
159 pub fn mfa_issuer(mut self, issuer: impl Into<String>) -> Self {
162 self.mfa_issuer = Some(issuer.into());
163 self
164 }
165
166 pub fn custom_fields_schema(mut self, schema: serde_json::Value) -> Self {
169 self.custom_fields_schema = Some(schema);
170 self
171 }
172
173 pub fn login(mut self) -> Self {
176 self.routes.insert(RouteGroup::Login);
177 self
178 }
179
180 pub fn register(mut self) -> Self {
181 self.routes.insert(RouteGroup::Register);
182 self
183 }
184
185 pub fn logout(mut self) -> Self {
186 self.routes.insert(RouteGroup::Logout);
187 self
188 }
189
190 pub fn settings(mut self) -> Self {
191 self.routes.insert(RouteGroup::Settings);
192 self
193 }
194
195 pub fn consent(mut self) -> Self {
196 self.routes.insert(RouteGroup::Consent);
197 self
198 }
199
200 pub fn password_reset(mut self) -> Self {
201 self.routes.insert(RouteGroup::PasswordReset);
202 self
203 }
204
205 pub fn mfa(mut self) -> Self {
206 self.routes.insert(RouteGroup::Mfa);
207 self
208 }
209
210 pub fn oauth(mut self) -> Self {
211 self.routes.insert(RouteGroup::OAuth);
212 self
213 }
214
215 pub fn token(mut self) -> Self {
216 self.routes.insert(RouteGroup::Token);
217 self
218 }
219
220 pub fn userinfo(mut self) -> Self {
221 self.routes.insert(RouteGroup::UserInfo);
222 self
223 }
224
225 pub fn well_known(mut self) -> Self {
226 self.routes.insert(RouteGroup::WellKnown);
227 self
228 }
229
230 pub fn all_routes(mut self) -> Self {
231 self.all = true;
232 self
233 }
234
235 fn selected(&self, group: RouteGroup) -> bool {
238 self.all || self.routes.contains(&group)
239 }
240
241 pub fn build(mut self, ath: &AllowThem) -> Result<Router<AllowThem>, AllRoutesError> {
242 if !self.all && self.routes.is_empty() {
243 return Err(AllRoutesError::NoRoutesSelected);
244 }
245
246 let needs_base_url = self.selected(RouteGroup::PasswordReset)
249 || self.selected(RouteGroup::Mfa)
250 || self.selected(RouteGroup::OAuth)
251 || self.selected(RouteGroup::WellKnown);
252
253 if needs_base_url && self.base_url.is_none() {
254 return Err(AllRoutesError::MissingConfig(
255 "base_url required by selected route groups".into(),
256 ));
257 }
258
259 if self.events_tx.is_some() && self.base_url.is_none() {
260 return Err(AllRoutesError::MissingConfig(
261 "base_url required when events channel is configured".into(),
262 ));
263 }
264
265 if self.selected(RouteGroup::OAuth) && self.oauth_provider_impls.is_none() {
266 return Err(AllRoutesError::MissingConfig(
267 "oauth_providers required when oauth routes are selected".into(),
268 ));
269 }
270
271 if self.selected(RouteGroup::PasswordReset) && self.email_sender.is_none() {
272 return Err(AllRoutesError::MissingConfig(
273 "email_sender required when password_reset routes are selected".into(),
274 ));
275 }
276
277 if self.selected(RouteGroup::Mfa) && self.mfa_issuer.is_none() {
278 return Err(AllRoutesError::MissingConfig(
279 "mfa_issuer required when mfa routes are selected".into(),
280 ));
281 }
282
283 let templates = self
286 .templates
287 .take()
288 .unwrap_or_else(build_default_browser_env);
289 let is_production = self.is_production;
290
291 let oauth_providers_list = self
295 .oauth_providers_list
296 .take()
297 .unwrap_or_else(|| match &self.oauth_provider_impls {
298 Some(map) => {
299 let mut names: Vec<String> = map.keys().cloned().collect();
300 names.sort();
301 names
302 }
303 None => Vec::new(),
304 });
305
306 let mut csrf_protected: Router<AllowThem> = Router::new();
309
310 if self.selected(RouteGroup::Login) {
311 csrf_protected = csrf_protected.merge(crate::login_routes::login_routes(
312 templates.clone(),
313 is_production,
314 self.max_login_attempts,
315 self.rate_limit_window_secs,
316 oauth_providers_list,
317 ));
318 }
319
320 if self.selected(RouteGroup::Register) {
321 let custom_schema = if let Some(schema) = self.custom_fields_schema.take() {
322 crate::custom_fields::validate_custom_schema(&schema)
323 .map_err(AllRoutesError::InvalidSchema)?;
324 let validator = jsonschema::validator_for(&schema)
325 .map_err(|e| AllRoutesError::InvalidSchema(e.to_string()))?;
326 let fields = crate::custom_fields::extract_field_descriptors(&schema);
327 Some(crate::custom_fields::CustomSchemaConfig {
328 schema,
329 validator,
330 fields,
331 })
332 } else {
333 None
334 };
335 csrf_protected = csrf_protected.merge(crate::register_routes::register_routes(
336 templates.clone(),
337 is_production,
338 custom_schema,
339 self.events_tx.clone(),
340 self.base_url.clone(),
341 ));
342 }
343
344 if self.selected(RouteGroup::Logout) {
345 csrf_protected = csrf_protected.merge(crate::logout_routes::logout_routes());
346 }
347
348 if self.selected(RouteGroup::Settings) {
349 csrf_protected = csrf_protected.merge(crate::settings_routes::settings_routes(
350 templates.clone(),
351 is_production,
352 ));
353 }
354
355 if self.selected(RouteGroup::Consent) {
356 csrf_protected = csrf_protected.merge(crate::consent_routes::consent_routes(
357 templates.clone(),
358 is_production,
359 ));
360 }
361
362 if self.selected(RouteGroup::PasswordReset) {
363 let email_sender = self.email_sender.clone().expect("validated above");
364 let base_url = self.base_url.clone().expect("validated above");
365 csrf_protected = csrf_protected.merge(
366 crate::password_reset_page_routes::password_reset_page_routes(
367 templates.clone(),
368 is_production,
369 email_sender,
370 base_url,
371 ),
372 );
373 }
374
375 if self.selected(RouteGroup::Mfa) {
376 let base_url = self.base_url.clone().expect("validated above");
377 csrf_protected = csrf_protected.merge(crate::mfa_page_routes::mfa_setup_routes(
378 templates.clone(),
379 is_production,
380 base_url,
381 ));
382 }
383
384 let mut non_csrf: Router<AllowThem> = Router::new();
387
388 if self.selected(RouteGroup::Mfa) {
389 non_csrf = non_csrf.merge(crate::mfa_page_routes::mfa_challenge_routes(
390 templates.clone(),
391 is_production,
392 ));
393 let issuer = self.mfa_issuer.take().expect("validated above");
394 non_csrf = non_csrf.merge(crate::mfa_routes::mfa_routes(issuer));
395 }
396
397 if self.selected(RouteGroup::OAuth) {
398 let providers = self.oauth_provider_impls.take().expect("validated above");
399 let base_url = self.base_url.clone().expect("validated above");
400 non_csrf = non_csrf.merge(crate::oauth_routes::oauth_routes(
401 providers,
402 base_url,
403 self.events_tx.clone(),
404 ));
405 }
406
407 if self.selected(RouteGroup::Token) {
408 non_csrf = non_csrf.merge(crate::token_route::token_route());
409 }
410
411 if self.selected(RouteGroup::UserInfo) {
412 non_csrf = non_csrf.merge(crate::userinfo_route::userinfo_route());
413 }
414
415 if self.selected(RouteGroup::WellKnown) {
416 let base_url = self.base_url.clone().expect("validated above");
417 non_csrf = non_csrf.merge(crate::well_known_routes::well_known_routes(base_url));
418 }
419
420 if self.selected(RouteGroup::PasswordReset) {
421 let email_sender = self.email_sender.take().expect("validated above");
422 let base_url = self.base_url.expect("validated above");
423 non_csrf = non_csrf.merge(crate::password_reset_routes::password_reset_routes(
424 email_sender,
425 base_url,
426 ));
427 }
428
429 Ok(csrf_protected
433 .layer(axum::middleware::from_fn_with_state(
434 ath.clone(),
435 crate::csrf::csrf_middleware,
436 ))
437 .merge(non_csrf))
438 }
439}
440
#[cfg(test)]
mod tests {
    use super::*;
    use allowthem_core::AllowThemBuilder;

    /// In-memory `AllowThem` with a fixed CSRF key, shared by every test.
    async fn test_ath() -> AllowThem {
        AllowThemBuilder::new("sqlite::memory:")
            .csrf_key(*b"test-csrf-key-for-server-tests!!")
            .build()
            .await
            .unwrap()
    }

    /// Asserts `result` is `MissingConfig` and that its message mentions
    /// `needle`. This pins *which* requirement tripped, not just the variant.
    fn assert_missing_config(result: Result<Router<AllowThem>, AllRoutesError>, needle: &str) {
        match result {
            Err(AllRoutesError::MissingConfig(msg)) => assert!(
                msg.contains(needle),
                "expected MissingConfig mentioning {needle:?}, got {msg:?}"
            ),
            other => panic!("expected MissingConfig, got {other:?}"),
        }
    }

    #[test]
    fn error_display_messages() {
        assert_eq!(
            AllRoutesError::NoRoutesSelected.to_string(),
            "no route groups selected"
        );
        assert_eq!(
            AllRoutesError::MissingConfig("base_url".into()).to_string(),
            "missing config: base_url"
        );
        assert_eq!(
            AllRoutesError::InvalidSchema("bad".into()).to_string(),
            "invalid custom fields schema: bad"
        );
    }

    #[tokio::test]
    async fn build_fails_no_routes_selected() {
        let ath = test_ath().await;
        let result = AllRoutesBuilder::new().build(&ath);
        assert!(matches!(result, Err(AllRoutesError::NoRoutesSelected)));
    }

    #[tokio::test]
    async fn build_fails_oauth_without_providers() {
        let ath = test_ath().await;
        let result = AllRoutesBuilder::new()
            .base_url("http://localhost")
            .oauth()
            .build(&ath);
        assert_missing_config(result, "oauth_providers");
    }

    #[tokio::test]
    async fn build_fails_oauth_without_base_url() {
        let ath = test_ath().await;
        let result = AllRoutesBuilder::new().oauth().build(&ath);
        assert_missing_config(result, "base_url");
    }

    #[tokio::test]
    async fn build_fails_password_reset_without_email_sender() {
        let ath = test_ath().await;
        let result = AllRoutesBuilder::new()
            .base_url("http://localhost")
            .password_reset()
            .build(&ath);
        assert_missing_config(result, "email_sender");
    }

    #[tokio::test]
    async fn build_fails_mfa_without_issuer() {
        let ath = test_ath().await;
        let result = AllRoutesBuilder::new()
            .base_url("http://localhost")
            .mfa()
            .build(&ath);
        assert_missing_config(result, "mfa_issuer");
    }

    #[tokio::test]
    async fn build_fails_events_without_base_url() {
        let ath = test_ath().await;
        let (tx, _rx) = tokio::sync::mpsc::unbounded_channel();
        let result = AllRoutesBuilder::new().events(tx).register().build(&ath);
        assert_missing_config(result, "events channel");
    }

    #[tokio::test]
    async fn build_fails_missing_base_url() {
        let ath = test_ath().await;
        let result = AllRoutesBuilder::new().well_known().build(&ath);
        assert_missing_config(result, "base_url");
    }

    #[tokio::test]
    async fn build_fails_all_routes_without_base_url() {
        let ath = test_ath().await;
        let result = AllRoutesBuilder::new().all_routes().build(&ath);
        assert_missing_config(result, "base_url");
    }

    #[tokio::test]
    async fn build_with_invalid_schema_returns_error() {
        let ath = test_ath().await;
        let schema = serde_json::json!({"type": "array"});
        let result = AllRoutesBuilder::new()
            .register()
            .custom_fields_schema(schema)
            .build(&ath);
        assert!(matches!(result, Err(AllRoutesError::InvalidSchema(_))));
    }

    #[tokio::test]
    async fn build_succeeds_for_simple_routes() {
        let ath = test_ath().await;
        let result = AllRoutesBuilder::new()
            .login()
            .register()
            .logout()
            .settings()
            .consent()
            .token()
            .userinfo()
            .build(&ath);
        assert!(result.is_ok());
    }
}