use std::collections::HashMap;
use std::sync::Arc;
use axum::Form;
use axum::Router;
use axum::extract::{Extension, Query};
use axum::http::header::{COOKIE, USER_AGENT};
use axum::http::{HeaderMap, StatusCode};
use axum::response::{IntoResponse, Response};
use axum::routing::get;
use axum_htmx::{HxBoosted, HxRequest};
use chrono::Utc;
use serde::Deserialize;
use allowthem_core::applications::BrandingConfig;
use allowthem_core::types::ClientId;
use allowthem_core::{
AllowThem, AuditEvent, AuthError, Email, EventContext, LifecycleEvent, LifecycleEventSender,
RegisteredEvent, RegistrationSource, Username, generate_token, hash_token,
};
use crate::auth_views::{RegisterView, register_fragment, register_page};
use crate::branding::{DefaultBranding, default_branding_ref, resolve_branding};
use crate::browser_error::BrowserError;
use crate::csrf::CsrfToken;
use crate::custom_fields::{
CustomFieldDescriptor, CustomSchemaConfig, extract_and_coerce_custom_data,
format_validation_errors,
};
use crate::events::{client_ip, publish};
/// Minimum password length checked by `POST /register`.
const MIN_PASSWORD_LEN: usize = 8;

/// Settings shared with the register handlers through an `Extension` layer
/// (installed by `register_routes`).
#[derive(Clone)]
struct RegisterConfig {
    /// When `false`, registering requires a valid invitation token.
    public_registration: bool,
    /// Passed through to the templates as `is_production`.
    is_production: bool,
    /// Optional JSON-schema backed extra fields (`custom_data[...]` inputs).
    custom_schema: Option<Arc<CustomSchemaConfig>>,
    /// OAuth provider names handed to the templates.
    oauth_providers: Vec<String>,
    /// Destination for the `Registered` lifecycle event; `None` skips publishing.
    events_tx: Option<LifecycleEventSender>,
    /// Base URL recorded in the event context (empty string when unset).
    base_url: Option<String>,
}

/// Query string accepted by both `GET` and `POST /register`.
#[derive(Deserialize)]
struct RegisterQuery {
    /// Application whose branding is applied to the form.
    #[serde(default)]
    client_id: Option<ClientId>,
    /// Invitation token; only `GET` reads it here (`POST` reads the form body).
    #[serde(default)]
    token: Option<String>,
}
#[allow(clippy::too_many_arguments)]
async fn get_register(
Extension(ath): Extension<AllowThem>,
Extension(config): Extension<RegisterConfig>,
default_branding: Option<Extension<Arc<DefaultBranding>>>,
headers: HeaderMap,
csrf: CsrfToken,
Query(query): Query<RegisterQuery>,
HxBoosted(boosted): HxBoosted,
HxRequest(request): HxRequest,
) -> Result<Response, BrowserError> {
if is_authenticated(&ath, &headers).await {
return Ok((StatusCode::SEE_OTHER, [(axum::http::header::LOCATION, "/")]).into_response());
}
let default = default_branding_ref(&default_branding);
let branding = resolve_branding(&ath, query.client_id.as_ref(), default).await;
let custom_fields = config
.custom_schema
.as_ref()
.map(|s| s.fields.as_slice())
.unwrap_or(&[]);
let empty: HashMap<String, String> = HashMap::new();
let invitation = match invitation_for_token(&ath, query.token.as_deref()).await {
Ok(invitation) => invitation,
Err(error) => {
let params = RegisterFormParams {
csrf_token: csrf.as_str(),
email: "",
username: "",
error: &error,
client_id: query.client_id.as_ref(),
branding: branding.as_ref(),
custom_fields,
custom_values: &empty,
token: query.token.as_deref(),
email_readonly: false,
registration_disabled: !config.public_registration,
};
return render_register_response(&config, params, request && !boosted);
}
};
let invited_email = invitation
.as_ref()
.and_then(|inv| inv.email.as_ref())
.map(|email| email.as_str())
.unwrap_or("");
let params = RegisterFormParams {
csrf_token: csrf.as_str(),
email: invited_email,
username: "",
error: if !config.public_registration && invitation.is_none() {
"Registration is disabled. Use an invitation link."
} else {
""
},
client_id: query.client_id.as_ref(),
branding: branding.as_ref(),
custom_fields,
custom_values: &empty,
token: query.token.as_deref(),
email_readonly: !invited_email.is_empty(),
registration_disabled: !config.public_registration && invitation.is_none(),
};
render_register_response(&config, params, request && !boosted)
}
async fn post_register(
Extension(ath): Extension<AllowThem>,
Extension(config): Extension<RegisterConfig>,
default_branding: Option<Extension<Arc<DefaultBranding>>>,
csrf: CsrfToken,
Query(query): Query<RegisterQuery>,
headers: HeaderMap,
Form(form): Form<HashMap<String, String>>,
) -> Result<Response, BrowserError> {
let default = default_branding_ref(&default_branding);
let branding = resolve_branding(&ath, query.client_id.as_ref(), default).await;
let cid = query.client_id.as_ref();
let br = branding.as_ref();
let custom_fields = config
.custom_schema
.as_ref()
.map(|s| s.fields.as_slice())
.unwrap_or(&[]);
let token_raw = form
.get("token")
.map(String::as_str)
.filter(|v| !v.is_empty());
let invitation = match invitation_for_token(&ath, token_raw).await {
Ok(invitation) => invitation,
Err(error) => {
return render_form_error(
&config,
&csrf,
form.get("email").map(String::as_str).unwrap_or(""),
form.get("username").map(String::as_str).unwrap_or(""),
&error,
cid,
br,
custom_fields,
&form,
token_raw,
false,
!config.public_registration,
);
}
};
let email_raw = form.get("email").map(String::as_str).unwrap_or("");
let password = form.get("password").map(String::as_str).unwrap_or("");
let password_confirm = form
.get("password_confirm")
.map(String::as_str)
.unwrap_or("");
let username_raw = form.get("username").map(String::as_str).unwrap_or("");
let email_readonly = invitation
.as_ref()
.and_then(|inv| inv.email.as_ref())
.is_some();
if !config.public_registration && invitation.is_none() {
return render_form_error(
&config,
&csrf,
email_raw,
username_raw,
"Registration is disabled. Use an invitation link.",
cid,
br,
custom_fields,
&form,
token_raw,
email_readonly,
true,
);
}
if password != password_confirm {
return render_form_error(
&config,
&csrf,
email_raw,
username_raw,
"Passwords do not match",
cid,
br,
custom_fields,
&form,
token_raw,
email_readonly,
false,
);
}
if password.len() < MIN_PASSWORD_LEN {
return render_form_error(
&config,
&csrf,
email_raw,
username_raw,
"Password must be at least 8 characters",
cid,
br,
custom_fields,
&form,
token_raw,
email_readonly,
false,
);
}
let email = match Email::new(email_raw.to_string()) {
Ok(e) => e,
Err(_) => {
return render_form_error(
&config,
&csrf,
email_raw,
username_raw,
"Invalid email address",
cid,
br,
custom_fields,
&form,
token_raw,
email_readonly,
false,
);
}
};
if let Some(invitation) = &invitation
&& let Some(expected) = invitation.email.as_ref()
&& !expected.as_str().eq_ignore_ascii_case(email.as_str())
{
return render_form_error(
&config,
&csrf,
expected.as_str(),
username_raw,
"This invitation is for a different email address",
cid,
br,
custom_fields,
&form,
token_raw,
true,
false,
);
}
let trimmed = username_raw.trim();
let username = if trimmed.is_empty() {
None
} else {
Some(Username::new(trimmed))
};
let custom_data = if let Some(schema_config) = &config.custom_schema {
let coerced = extract_and_coerce_custom_data(&form, &schema_config.schema);
let errors: Vec<_> = schema_config.validator.iter_errors(&coerced).collect();
if !errors.is_empty() {
let field_errors = format_validation_errors(&errors);
let error_msg = field_errors
.iter()
.map(|(field, msg)| {
if field.is_empty() {
msg.clone()
} else {
format!("{field}: {msg}")
}
})
.collect::<Vec<_>>()
.join("; ");
return render_form_error(
&config,
&csrf,
email_raw,
username_raw,
&error_msg,
cid,
br,
custom_fields,
&form,
token_raw,
email_readonly,
false,
);
}
Some(coerced)
} else {
None
};
let user = match ath
.create_user(email, password, username, custom_data.as_ref())
.await
{
Ok(u) => u,
Err(AuthError::Conflict(ref msg)) if msg.contains("email") => {
return render_form_error(
&config,
&csrf,
email_raw,
username_raw,
"Registration could not be completed. If you already have an account, try logging in.",
cid,
br,
custom_fields,
&form,
token_raw,
email_readonly,
false,
);
}
Err(AuthError::Conflict(ref msg)) if msg.contains("username") => {
return render_form_error(
&config,
&csrf,
email_raw,
username_raw,
"This username is already taken",
cid,
br,
custom_fields,
&form,
token_raw,
email_readonly,
false,
);
}
Err(e) => return Err(BrowserError::Auth(e)),
};
if invitation.is_some() {
ath.db().set_email_verified(user.id, true).await?;
}
let token = generate_token();
let token_hash = hash_token(&token);
let expires_at = Utc::now() + ath.session_config().ttl;
let ip = client_ip(&headers);
let ua = headers.get(USER_AGENT).and_then(|v| v.to_str().ok());
ath.db()
.create_session(user.id, token_hash, ip.as_deref(), ua, expires_at)
.await?;
if let Err(e) = ath
.db()
.log_audit(
AuditEvent::Register,
Some(&user.id),
None,
ip.as_deref(),
ua,
None,
)
.await
{
tracing::error!(error = %e, "failed to log registration audit event");
}
if let Some(invitation) = &invitation {
ath.db().consume_invitation(invitation.id).await?;
}
publish(config.events_tx.as_ref(), || {
let source = invitation
.as_ref()
.map(|invitation| RegistrationSource::Invitation {
email: invitation
.email
.as_ref()
.map(|email| email.as_str().to_string()),
metadata: invitation.metadata.clone(),
})
.unwrap_or(RegistrationSource::Password);
LifecycleEvent::Registered(RegisteredEvent::new(
user.clone(),
source,
EventContext::new(
ip.clone(),
ua.map(str::to_owned),
config.base_url.clone().unwrap_or_default(),
Utc::now(),
),
))
});
let cookie = ath.session_cookie(&token);
Ok((
StatusCode::SEE_OTHER,
[
(axum::http::header::SET_COOKIE, cookie),
(axum::http::header::LOCATION, "/".to_string()),
],
)
.into_response())
}
/// Everything needed to render the registration form once, borrowed from the
/// calling handler's locals.
struct RegisterFormParams<'a> {
    /// CSRF token rendered into the form.
    csrf_token: &'a str,
    /// Message to display; `""` when there is none.
    error: &'a str,
    /// Value echoed back into the email input.
    email: &'a str,
    /// Value echoed back into the username input.
    username: &'a str,
    /// Renders the email input read-only (invitation-bound address).
    email_readonly: bool,
    /// Set when public registration is off and no valid invitation is present.
    registration_disabled: bool,
    /// Application the form is shown for, if any.
    client_id: Option<&'a ClientId>,
    /// Resolved branding for the page, if any.
    branding: Option<&'a BrandingConfig>,
    /// Descriptors for the schema-driven extra inputs.
    custom_fields: &'a [CustomFieldDescriptor],
    /// Raw submitted form values; `custom_data[...]` entries are re-displayed.
    custom_values: &'a HashMap<String, String>,
    /// Invitation token carried along with the form.
    token: Option<&'a str>,
}
/// Resolves an optional invitation token.
///
/// A missing or blank (after trimming) token yields `Ok(None)`. A token the
/// store does not recognise yields a user-facing error message. A storage
/// failure is logged and also reported as a user-facing message.
async fn invitation_for_token(
    ath: &AllowThem,
    token: Option<&str>,
) -> Result<Option<allowthem_core::Invitation>, String> {
    let token = match token.map(str::trim) {
        Some(trimmed) if !trimmed.is_empty() => trimmed,
        _ => return Ok(None),
    };

    let found = ath.db().validate_invitation(token).await.map_err(|error| {
        tracing::error!(%error, "failed to validate invitation token");
        "This invitation link could not be validated.".to_string()
    })?;
    if found.is_none() {
        return Err("This invitation link is invalid or has expired.".to_string());
    }
    Ok(found)
}
/// Renders the registration form as an htmx fragment (`fragment == true`) or
/// as a full page, and converts it into a response.
fn render_register_response(
    config: &RegisterConfig,
    params: RegisterFormParams<'_>,
    fragment: bool,
) -> Result<Response, BrowserError> {
    let html = if fragment {
        render_register_fragment(config, params)?
    } else {
        render_register_form(config, params)?
    };
    Ok(html.into_response())
}
/// Renders the full-page registration form.
fn render_register_form(
    config: &RegisterConfig,
    params: RegisterFormParams<'_>,
) -> Result<axum::response::Html<String>, BrowserError> {
    render_register_markup(config, params, false)
}

/// Renders only the htmx fragment of the registration form (the swappable
/// `<main>` plus its out-of-band head updates).
fn render_register_fragment(
    config: &RegisterConfig,
    params: RegisterFormParams<'_>,
) -> Result<axum::response::Html<String>, BrowserError> {
    render_register_markup(config, params, true)
}

/// Builds the `RegisterView` shared by both render paths and renders it with
/// `register_fragment` when `fragment` is set, `register_page` otherwise.
fn render_register_markup(
    config: &RegisterConfig,
    params: RegisterFormParams<'_>,
    fragment: bool,
) -> Result<axum::response::Html<String>, BrowserError> {
    // Templates look custom values up by bare field name, so the
    // `custom_data[...]` wrapper used by the inputs is stripped first.
    let custom_values = custom_values_map(params.custom_values);
    let view = RegisterView {
        csrf_token: params.csrf_token,
        email: params.email,
        username: params.username,
        error: params.error,
        client_id: params.client_id.map(|c| c.as_str()),
        branding: params.branding,
        custom_fields: params.custom_fields,
        custom_values: &custom_values,
        token: params.token,
        email_readonly: params.email_readonly,
        registration_disabled: params.registration_disabled,
        oauth_providers: &config.oauth_providers,
        is_production: config.is_production,
    };
    if fragment {
        register_fragment(&view)
    } else {
        register_page(&view)
    }
}
/// Picks the `custom_data[<field>]` entries out of raw form values and keys
/// them by the bare `<field>` name; every other entry is ignored.
fn custom_values_map(values: &HashMap<String, String>) -> HashMap<&str, &str> {
    let mut picked = HashMap::new();
    for (key, value) in values {
        let Some(inner) = key.strip_prefix("custom_data[") else {
            continue;
        };
        if let Some(field) = inner.strip_suffix(']') {
            picked.insert(field, value.as_str());
        }
    }
    picked
}
#[allow(clippy::too_many_arguments)]
fn render_form_error(
config: &RegisterConfig,
csrf: &CsrfToken,
email: &str,
username: &str,
error: &str,
client_id: Option<&ClientId>,
branding: Option<&BrandingConfig>,
custom_fields: &[CustomFieldDescriptor],
form_data: &HashMap<String, String>,
token: Option<&str>,
email_readonly: bool,
registration_disabled: bool,
) -> Result<Response, BrowserError> {
let html = render_register_form(
config,
RegisterFormParams {
csrf_token: csrf.as_str(),
email,
username,
error,
client_id,
branding,
custom_fields,
custom_values: form_data,
token,
email_readonly,
registration_disabled,
},
)?;
Ok(html.into_response())
}
/// Returns `true` when the request's `Cookie` header carries a session that
/// the store still accepts. A missing or unparsable cookie, an unknown or
/// expired session, and a storage error all count as "not authenticated".
async fn is_authenticated(ath: &AllowThem, headers: &HeaderMap) -> bool {
    let session_token = headers
        .get(COOKIE)
        .and_then(|value| value.to_str().ok())
        .and_then(|raw| ath.parse_session_cookie(raw));
    let Some(session_token) = session_token else {
        return false;
    };
    let ttl = ath.session_config().ttl;
    matches!(
        ath.db().validate_session(&session_token, ttl).await,
        Ok(Some(_))
    )
}
/// Builds the `/register` router: `GET` renders the form, `POST` submits it.
///
/// The settings are shared with both handlers through an `Extension` layer.
/// The handlers also extract `Extension<AllowThem>`, so callers must make
/// `AllowThem` available as a request extension.
pub fn register_routes(
    is_production: bool,
    custom_schema: Option<CustomSchemaConfig>,
    events_tx: Option<LifecycleEventSender>,
    base_url: Option<String>,
    oauth_providers: Vec<String>,
    public_registration: bool,
) -> Router<()> {
    let shared = RegisterConfig {
        public_registration,
        is_production,
        custom_schema: custom_schema.map(Arc::new),
        oauth_providers,
        events_tx,
        base_url,
    };
    let routes: Router<()> =
        Router::new().route("/register", get(get_register).post(post_register));
    routes.layer(Extension(shared))
}
#[cfg(test)]
mod tests {
    // End-to-end tests for the `/register` routes. The router is driven with
    // `tower::ServiceExt::oneshot` against an in-memory SQLite `AllowThem`.
    use axum::Router;
    use axum::body::Body;
    use axum::http::{Request, StatusCode, header};
    use serde_json::json;
    use tower::ServiceExt;
    use allowthem_core::applications::CreateApplicationParams;
    use allowthem_core::types::ClientType;
    use allowthem_core::{
        AllowThem, AllowThemBuilder, AuditEvent, Email, LifecycleEvent, RegistrationSource,
        Username, parse_session_cookie,
    };
    use crate::custom_fields::{CustomSchemaConfig, extract_field_descriptors};
    use super::{RegisterConfig, RegisterFormParams, register_routes, render_register_fragment};

    /// Fresh in-memory `AllowThem` plus a permissive default config.
    async fn setup() -> (AllowThem, RegisterConfig) {
        let ath = AllowThemBuilder::new("sqlite::memory:")
            .cookie_secure(false)
            .csrf_key(*b"test-csrf-key-for-binary-tests!!")
            .build()
            .await
            .unwrap();
        let config = RegisterConfig {
            is_production: false,
            custom_schema: None,
            events_tx: None,
            base_url: None,
            oauth_providers: Vec::new(),
            public_registration: true,
        };
        (ath, config)
    }

    /// Wraps `router` in the middleware the register routes rely on: CSRF
    /// checking (inner) and `AllowThem` injection into request extensions
    /// (outer).
    fn with_test_layers(router: Router, ath: AllowThem) -> Router {
        router
            .layer(axum::middleware::from_fn(crate::csrf::csrf_middleware))
            .layer(axum::middleware::from_fn_with_state(
                ath,
                crate::cors::inject_ath_into_extensions,
            ))
    }

    fn test_app(ath: AllowThem, config: RegisterConfig) -> Router {
        let schema = config.custom_schema.as_ref().map(|arc| CustomSchemaConfig {
            schema: arc.schema.clone(),
            validator: arc.validator.clone(),
            fields: arc.fields.clone(),
        });
        let routes = register_routes(
            config.is_production,
            schema,
            config.events_tx,
            config.base_url,
            config.oauth_providers,
            config.public_registration,
        );
        with_test_layers(routes, ath)
    }

    fn test_app_with_schema(ath: AllowThem, config: RegisterConfig) -> Router {
        let schema = json!({
            "type": "object",
            "required": ["company"],
            "properties": {
                "company": {
                    "type": "string",
                    "title": "Company Name",
                    "minLength": 1
                },
                "age": {
                    "type": "integer",
                    "minimum": 0,
                    "maximum": 120,
                    "default": 21
                },
                "newsletter": {
                    "type": "boolean"
                }
            }
        });
        let validator = jsonschema::validator_for(&schema).expect("valid schema");
        let fields = extract_field_descriptors(&schema);
        let schema_config = CustomSchemaConfig {
            schema,
            validator,
            fields,
        };
        let routes = register_routes(
            config.is_production,
            Some(schema_config),
            config.events_tx,
            config.base_url,
            config.oauth_providers,
            config.public_registration,
        );
        with_test_layers(routes, ath)
    }

    /// Performs a `GET /register` and returns the CSRF cookie value it sets.
    async fn get_csrf_token(app: &Router) -> String {
        let req = Request::builder()
            .uri("/register")
            .body(Body::empty())
            .unwrap();
        let resp = app.clone().oneshot(req).await.unwrap();
        let set_cookie = resp
            .headers()
            .get(header::SET_COOKIE)
            .unwrap()
            .to_str()
            .unwrap()
            .to_string();
        let pair = set_cookie.split(';').next().unwrap();
        // `split_once` keeps any `=` inside the value (e.g. base64 padding).
        let (_, value) = pair
            .split_once('=')
            .expect("csrf cookie should be name=value");
        value.to_string()
    }

    /// Percent-encodes `value` for an `application/x-www-form-urlencoded` body.
    fn form_encode(value: &str) -> String {
        let mut out = String::with_capacity(value.len());
        for byte in value.bytes() {
            match byte {
                b'A'..=b'Z' | b'a'..=b'z' | b'0'..=b'9' | b'-' | b'.' | b'_' | b'*' => {
                    out.push(char::from(byte))
                }
                _ => out.push_str(&format!("%{byte:02X}")),
            }
        }
        out
    }

    fn register_request(
        csrf: &str,
        email: &str,
        password: &str,
        confirm: &str,
        username: &str,
    ) -> Request<Body> {
        register_request_with_custom(csrf, email, password, confirm, username, &[])
    }

    /// Builds a CSRF-authenticated `POST /register`; each `(key, value)` in
    /// `custom_fields` is sent as `custom_data[key]=value`.
    fn register_request_with_custom(
        csrf: &str,
        email: &str,
        password: &str,
        confirm: &str,
        username: &str,
        custom_fields: &[(&str, &str)],
    ) -> Request<Body> {
        let mut body = format!(
            "csrf_token={}&email={}&password={}&password_confirm={}&username={}",
            form_encode(csrf),
            form_encode(email),
            form_encode(password),
            form_encode(confirm),
            form_encode(username),
        );
        for (key, value) in custom_fields {
            body.push_str(&format!(
                "&custom_data%5B{}%5D={}",
                form_encode(key),
                form_encode(value)
            ));
        }
        Request::builder()
            .method("POST")
            .uri("/register")
            .header(header::CONTENT_TYPE, "application/x-www-form-urlencoded")
            .header(header::COOKIE, format!("csrf_pre={}", csrf))
            .body(Body::from(body))
            .unwrap()
    }

    async fn body_string(resp: axum::http::Response<Body>) -> String {
        let bytes = axum::body::to_bytes(resp.into_body(), usize::MAX)
            .await
            .unwrap();
        String::from_utf8(bytes.to_vec()).unwrap()
    }

    #[tokio::test]
    async fn get_register_renders_form() {
        let (ath, config) = setup().await;
        let app = test_app(ath, config);
        let req = Request::builder()
            .uri("/register")
            .body(Body::empty())
            .unwrap();
        let resp = app.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
        let html = body_string(resp).await;
        assert!(html.contains("<form"));
        assert!(html.contains("csrf_token"));
        assert!(html.contains("name=\"email\""));
        assert!(html.contains("name=\"password\""));
        assert!(html.contains("name=\"password_confirm\""));
    }

    #[tokio::test]
    async fn post_register_success_redirects() {
        let (ath, config) = setup().await;
        let app = test_app(ath, config);
        let csrf = get_csrf_token(&app).await;
        let req = register_request(&csrf, "test@example.com", "password123", "password123", "");
        let resp = app.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::SEE_OTHER);
        assert_eq!(resp.headers().get("location").unwrap(), "/");
        let set_cookie = resp
            .headers()
            .get(header::SET_COOKIE)
            .unwrap()
            .to_str()
            .unwrap();
        assert!(set_cookie.contains("allowthem_session"));
    }

    #[tokio::test]
    async fn post_register_creates_user() {
        let (ath, config) = setup().await;
        let app = test_app(ath.clone(), config);
        let csrf = get_csrf_token(&app).await;
        let req = register_request(
            &csrf,
            "new@example.com",
            "password123",
            "password123",
            "myuser",
        );
        app.oneshot(req).await.unwrap();
        let email = Email::new("new@example.com".into()).unwrap();
        let user = ath.db().get_user_by_email(&email).await.unwrap();
        assert_eq!(user.email, email);
        assert_eq!(user.username.as_ref().map(|u| u.as_str()), Some("myuser"));
    }

    #[tokio::test]
    async fn post_register_creates_session() {
        let (ath, config) = setup().await;
        let app = test_app(ath.clone(), config);
        let csrf = get_csrf_token(&app).await;
        let req = register_request(&csrf, "sess@example.com", "password123", "password123", "");
        let resp = app.oneshot(req).await.unwrap();
        let set_cookie = resp
            .headers()
            .get(header::SET_COOKIE)
            .unwrap()
            .to_str()
            .unwrap();
        let token = parse_session_cookie(set_cookie, "allowthem_session")
            .expect("session cookie should be present");
        let session = ath.db().lookup_session(&token).await.unwrap();
        assert!(session.is_some(), "session should exist in DB");
    }

    #[tokio::test]
    async fn post_register_logs_audit() {
        let (ath, config) = setup().await;
        let app = test_app(ath.clone(), config);
        let csrf = get_csrf_token(&app).await;
        let req = register_request(&csrf, "audit@example.com", "password123", "password123", "");
        app.oneshot(req).await.unwrap();
        let entries = ath.db().get_audit_log(None, 10, 0).await.unwrap();
        let register_entry = entries
            .iter()
            .find(|e| e.event_type == AuditEvent::Register);
        assert!(
            register_entry.is_some(),
            "register audit event should be recorded"
        );
    }

    #[tokio::test]
    async fn post_register_password_mismatch() {
        let (ath, config) = setup().await;
        let app = test_app(ath, config);
        let csrf = get_csrf_token(&app).await;
        let req = register_request(
            &csrf,
            "mismatch@example.com",
            "password123",
            "different456",
            "",
        );
        let resp = app.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
        let html = body_string(resp).await;
        assert!(html.contains("Passwords do not match"));
        assert!(html.contains("mismatch@example.com"));
    }

    #[tokio::test]
    async fn post_register_password_too_short() {
        let (ath, config) = setup().await;
        let app = test_app(ath, config);
        let csrf = get_csrf_token(&app).await;
        let req = register_request(&csrf, "short@example.com", "abc", "abc", "");
        let resp = app.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
        let html = body_string(resp).await;
        assert!(html.contains("Password must be at least 8 characters"));
    }

    #[tokio::test]
    async fn post_register_invalid_email() {
        let (ath, config) = setup().await;
        let app = test_app(ath, config);
        let csrf = get_csrf_token(&app).await;
        let req = register_request(&csrf, "not-an-email", "password123", "password123", "");
        let resp = app.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
        let html = body_string(resp).await;
        assert!(html.contains("Invalid email address"));
    }

    #[tokio::test]
    async fn post_register_duplicate_email() {
        let (ath, config) = setup().await;
        let email = Email::new("dupe@example.com".into()).unwrap();
        ath.db()
            .create_user(email, "existing123", None, None)
            .await
            .unwrap();
        let app = test_app(ath, config);
        let csrf = get_csrf_token(&app).await;
        let req = register_request(&csrf, "dupe@example.com", "password123", "password123", "");
        let resp = app.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
        let html = body_string(resp).await;
        assert!(html.contains("Registration could not be completed"));
    }

    #[tokio::test]
    async fn post_register_duplicate_username() {
        let (ath, config) = setup().await;
        let email = Email::new("first@example.com".into()).unwrap();
        ath.db()
            .create_user(email, "existing123", Some(Username::new("taken")), None)
            .await
            .unwrap();
        let app = test_app(ath, config);
        let csrf = get_csrf_token(&app).await;
        let req = register_request(
            &csrf,
            "second@example.com",
            "password123",
            "password123",
            "taken",
        );
        let resp = app.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
        let html = body_string(resp).await;
        assert!(html.contains("This username is already taken"));
    }

    #[tokio::test]
    async fn post_register_session_cookie_authenticates() {
        let (ath, config) = setup().await;
        let app = test_app(ath.clone(), config);
        let csrf = get_csrf_token(&app).await;
        let req = register_request(&csrf, "auth@example.com", "password123", "password123", "");
        let resp = app.oneshot(req).await.unwrap();
        let set_cookie = resp
            .headers()
            .get(header::SET_COOKIE)
            .unwrap()
            .to_str()
            .unwrap();
        let token = parse_session_cookie(set_cookie, "allowthem_session")
            .expect("session cookie should be present");
        let ttl = ath.session_config().ttl;
        let session_result = ath.db().validate_session(&token, ttl).await.unwrap();
        assert!(
            session_result.is_some(),
            "session cookie issued at registration should be valid"
        );
    }

    #[tokio::test]
    async fn get_register_logged_in_redirects_to_root() {
        use allowthem_core::{generate_token, hash_token};
        use chrono::{Duration, Utc};
        let (ath, config) = setup().await;
        let email = Email::new("loggedin@example.com".into()).unwrap();
        let user = ath
            .db()
            .create_user(email, "password123", None, None)
            .await
            .unwrap();
        let token = generate_token();
        let token_hash = hash_token(&token);
        ath.db()
            .create_session(
                user.id,
                token_hash,
                None,
                None,
                Utc::now() + Duration::hours(24),
            )
            .await
            .unwrap();
        let set_cookie = ath.session_cookie(&token);
        let cookie_value = set_cookie.split(';').next().unwrap().to_string();
        let app = test_app(ath, config);
        let req = Request::builder()
            .uri("/register")
            .header(header::COOKIE, cookie_value)
            .body(Body::empty())
            .unwrap();
        let resp = app.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::SEE_OTHER);
        assert_eq!(resp.headers().get("location").unwrap(), "/");
    }

    #[tokio::test]
    async fn post_register_without_csrf_returns_403() {
        let (ath, config) = setup().await;
        let app = test_app(ath, config);
        let body = "email=test%40example.com&password=password123&password_confirm=password123";
        let req = Request::builder()
            .method("POST")
            .uri("/register")
            .header(header::CONTENT_TYPE, "application/x-www-form-urlencoded")
            .body(Body::from(body))
            .unwrap();
        let resp = app.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::FORBIDDEN);
    }

    #[tokio::test]
    async fn register_with_client_id_shows_branding() {
        let (ath, config) = setup().await;
        let (app, _) = ath
            .db()
            .create_application(CreateApplicationParams {
                name: "BrandedRegApp".into(),
                client_type: ClientType::Confidential,
                redirect_uris: vec!["https://example.com/cb".into()],
                is_trusted: false,
                created_by: None,
                logo_url: Some("https://cdn.example.com/logo.png".into()),
                primary_color: Some("#ff6600".into()),
                accent_hex: None,
                accent_ink: None,
                forced_mode: None,
                font_css_url: None,
                font_family: None,
                splash_text: None,
                splash_image_url: None,
                splash_primitive: None,
                splash_url: None,
                shader_cell_scale: None,
            })
            .await
            .unwrap();
        let router = test_app(ath, config);
        let req = Request::builder()
            .uri(&format!("/register?client_id={}", app.client_id))
            .body(Body::empty())
            .unwrap();
        let resp = router.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
        let html = body_string(resp).await;
        assert!(html.contains("BrandedRegApp"), "should show app name");
        assert!(
            html.contains("<title>Create account — BrandedRegApp</title>"),
            "default title brand should use the application name"
        );
        assert!(
            html.contains("--accent: #ff6600"),
            "primary_color should flow to --accent"
        );
        assert!(
            html.contains("--accent-ink:"),
            "accent_ink should be emitted in template"
        );
    }

    #[tokio::test]
    async fn register_without_client_id_shows_default() {
        let (ath, config) = setup().await;
        let router = test_app(ath, config);
        let req = Request::builder()
            .uri("/register")
            .body(Body::empty())
            .unwrap();
        let resp = router.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
        let html = body_string(resp).await;
        assert!(!html.contains("<img"), "no logo without client_id");
        assert!(
            html.contains("--accent: #ffffff"),
            "should have default white accent"
        );
        assert!(
            html.contains("--accent-ink: #000000"),
            "should have default black ink"
        );
    }

    #[tokio::test]
    async fn register_sign_in_link_carries_client_id() {
        let (ath, config) = setup().await;
        let (app, _) = ath
            .db()
            .create_application(CreateApplicationParams {
                name: "LinkApp".into(),
                client_type: ClientType::Confidential,
                redirect_uris: vec!["https://example.com/cb".into()],
                is_trusted: false,
                created_by: None,
                logo_url: None,
                primary_color: None,
                accent_hex: None,
                accent_ink: None,
                forced_mode: None,
                font_css_url: None,
                font_family: None,
                splash_text: None,
                splash_image_url: None,
                splash_primitive: None,
                splash_url: None,
                shader_cell_scale: None,
            })
            .await
            .unwrap();
        let router = test_app(ath, config);
        let req = Request::builder()
            .uri(&format!("/register?client_id={}", app.client_id))
            .body(Body::empty())
            .unwrap();
        let resp = router.oneshot(req).await.unwrap();
        let html = body_string(resp).await;
        let id = app.client_id.as_str();
        let login_link = format!("/login?client_id={id}");
        assert!(
            html.contains(&login_link),
            "sign-in link should carry client_id"
        );
    }

    #[tokio::test]
    async fn register_without_schema_works_as_before() {
        let (ath, config) = setup().await;
        let app = test_app(ath.clone(), config);
        let csrf = get_csrf_token(&app).await;
        let req = register_request(
            &csrf,
            "noschema@example.com",
            "password123",
            "password123",
            "",
        );
        let resp = app.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::SEE_OTHER);
        let email = Email::new("noschema@example.com".into()).unwrap();
        let user = ath.db().get_user_by_email(&email).await.unwrap();
        assert!(user.custom_data.is_none());
    }

    #[tokio::test]
    async fn register_with_custom_fields_stores_json() {
        let (ath, config) = setup().await;
        let app = test_app_with_schema(ath.clone(), config);
        let csrf = get_csrf_token(&app).await;
        let req = register_request_with_custom(
            &csrf,
            "custom@example.com",
            "password123",
            "password123",
            "",
            &[
                ("company", "Acme Corp"),
                ("age", "30"),
                ("newsletter", "true"),
            ],
        );
        let resp = app.oneshot(req).await.unwrap();
        assert_eq!(
            resp.status(),
            StatusCode::SEE_OTHER,
            "should redirect on success"
        );
        let email = Email::new("custom@example.com".into()).unwrap();
        let user = ath.db().get_user_by_email(&email).await.unwrap();
        let data = user
            .custom_data
            .as_ref()
            .expect("custom_data should be stored");
        assert_eq!(data["company"], "Acme Corp");
        assert_eq!(data["age"], 30);
        assert_eq!(data["newsletter"], true);
    }

    #[tokio::test]
    async fn register_with_invalid_custom_fields_shows_error() {
        let (ath, config) = setup().await;
        let app = test_app_with_schema(ath, config);
        let csrf = get_csrf_token(&app).await;
        let req = register_request_with_custom(
            &csrf,
            "invalid@example.com",
            "password123",
            "password123",
            "",
            &[("age", "25")],
        );
        let resp = app.oneshot(req).await.unwrap();
        assert_eq!(
            resp.status(),
            StatusCode::OK,
            "should re-render form on validation failure"
        );
        let html = body_string(resp).await;
        assert!(
            html.contains("company") || html.contains("required"),
            "should show validation error about missing required field"
        );
        assert!(
            html.contains("invalid@example.com"),
            "email should be preserved on error re-render"
        );
    }

    #[tokio::test]
    async fn register_custom_field_type_coercion() {
        let (ath, config) = setup().await;
        let app = test_app_with_schema(ath.clone(), config);
        let csrf = get_csrf_token(&app).await;
        let req = register_request_with_custom(
            &csrf,
            "coerce@example.com",
            "password123",
            "password123",
            "",
            &[("company", "Test Co"), ("age", "42")],
        );
        let resp = app.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::SEE_OTHER);
        let email = Email::new("coerce@example.com".into()).unwrap();
        let user = ath.db().get_user_by_email(&email).await.unwrap();
        let data = user.custom_data.as_ref().expect("custom_data");
        assert!(data["age"].is_i64(), "age should be coerced to integer");
        assert_eq!(data["age"], 42);
        assert_eq!(data["newsletter"], false);
    }

    #[tokio::test]
    async fn get_register_with_schema_renders_custom_fields() {
        let (ath, config) = setup().await;
        let app = test_app_with_schema(ath, config);
        let req = Request::builder()
            .uri("/register")
            .body(Body::empty())
            .unwrap();
        let resp = app.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
        let html = body_string(resp).await;
        assert!(
            html.contains("custom_data[company]"),
            "should render company custom field"
        );
        assert!(html.contains("Company Name"), "should render field label");
        assert!(
            html.contains(r#"name="custom_data[company]"#) && html.contains(r#"minlength="1""#),
            "string custom fields should preserve minLength"
        );
        assert!(
            html.contains(r#"name="custom_data[age]"#)
                && html.contains(r#"min="0""#)
                && html.contains(r#"max="120""#)
                && html.contains(r#"value="21""#),
            "number custom fields should preserve min/max/default attributes"
        );
    }

    fn test_app_with_events(
        ath: AllowThem,
        config: RegisterConfig,
        events_tx: allowthem_core::LifecycleEventSender,
        base_url: String,
    ) -> Router {
        let routes = register_routes(
            config.is_production,
            None,
            Some(events_tx),
            Some(base_url),
            config.oauth_providers,
            config.public_registration,
        );
        with_test_layers(routes, ath)
    }

    #[tokio::test]
    async fn post_register_publishes_registered_event() {
        let (ath, config) = setup().await;
        let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel();
        let app = test_app_with_events(ath, config, tx, "http://test".into());
        let csrf = get_csrf_token(&app).await;
        let req = register_request(
            &csrf,
            "events@example.com",
            "password123",
            "password123",
            "",
        );
        let resp = app.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::SEE_OTHER);
        let event = tokio::time::timeout(std::time::Duration::from_millis(200), rx.recv())
            .await
            .expect("event received before timeout")
            .expect("channel sender still alive");
        match event {
            LifecycleEvent::Registered(e) => {
                assert!(matches!(e.source, RegistrationSource::Password));
                assert_eq!(e.user.email.as_str(), "events@example.com");
                assert_eq!(e.ctx.base_url, "http://test");
            }
            other => panic!("unexpected event variant: {other:?}"),
        }
    }

    #[tokio::test]
    async fn post_register_does_not_hang_when_receiver_dropped() {
        let (ath, config) = setup().await;
        let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
        drop(rx);
        let app = test_app_with_events(ath, config, tx, "http://test".into());
        let csrf = get_csrf_token(&app).await;
        let req = register_request(
            &csrf,
            "dropped@example.com",
            "password123",
            "password123",
            "",
        );
        let resp = tokio::time::timeout(std::time::Duration::from_secs(10), app.oneshot(req))
            .await
            .expect("register handler returned in time")
            .unwrap();
        assert_eq!(resp.status(), StatusCode::SEE_OTHER);
    }

    #[tokio::test]
    async fn post_register_without_events_sender_succeeds() {
        let (ath, config) = setup().await;
        let app = test_app(ath, config);
        let csrf = get_csrf_token(&app).await;
        let req = register_request(
            &csrf,
            "no-events@example.com",
            "password123",
            "password123",
            "",
        );
        let resp = app.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::SEE_OTHER);
    }

    #[tokio::test]
    async fn render_register_fragment_composes_main_and_oob_head() {
        let (_ath, config) = setup().await;
        let empty: std::collections::HashMap<String, String> = std::collections::HashMap::new();
        let html = render_register_fragment(
            &config,
            RegisterFormParams {
                csrf_token: "tok",
                email: "",
                username: "",
                error: "",
                client_id: None,
                branding: None,
                custom_fields: &[],
                custom_values: &empty,
                token: None,
                email_readonly: false,
                registration_disabled: false,
            },
        )
        .unwrap()
        .0;
        assert!(
            html.contains("<main class=\"wf-auth-form\">"),
            "fragment must include the <main> root"
        );
        assert!(
            html.contains("<title hx-swap-oob=\"true\">"),
            "fragment must include the OOB <title> tag"
        );
        assert!(
            html.contains("id=\"wf-screen-label\""),
            "fragment must include the OOB #wf-screen-label span"
        );
    }
}