use crate::email::EmailProvider;
use crate::error::AuthError;
use chrono::Duration;
use std::collections::HashMap;
use std::sync::Arc;
/// Path constants for the built-in core endpoints.
///
/// NOTE(review): these look relative to `AuthConfig::base_path`
/// (e.g. `/api/auth/ok`); the router is not in this file — confirm.
pub mod core_paths {
/// Path of the `ok` endpoint.
pub const OK: &str = "/ok";
/// Path of the `error` endpoint.
pub const ERROR: &str = "/error";
/// Path of the health endpoint.
pub const HEALTH: &str = "/health";
/// Path at which the OpenAPI document is served.
pub const OPENAPI_SPEC: &str = "/reference/openapi.json";
/// Path of the update-user endpoint.
pub const UPDATE_USER: &str = "/update-user";
/// Path of the delete-user endpoint.
pub const DELETE_USER: &str = "/delete-user";
/// Path of the change-email endpoint.
pub const CHANGE_EMAIL: &str = "/change-email";
/// Callback path under the delete-user endpoint.
pub const DELETE_USER_CALLBACK: &str = "/delete-user/callback";
}
/// Top-level configuration for the authentication service.
///
/// Usually built with `AuthConfig::new` plus the builder methods; `validate`
/// checks the secret. `Debug` is not derived — presumably because
/// `dyn EmailProvider` does not implement it (TODO confirm). Any hand-written
/// `Debug` should redact `secret`.
#[derive(Clone)]
pub struct AuthConfig {
/// Server-side secret; `validate` rejects it when empty or shorter than 32 bytes.
pub secret: String,
/// Application name (default `"Better Auth"`).
pub app_name: String,
/// Base URL; its origin is always accepted by `is_origin_trusted`.
pub base_url: String,
/// Base path (default `/api/auth`). NOTE(review): presumably the route prefix.
pub base_path: String,
/// Additional trusted origins: literal origins or `*` glob patterns such as
/// `https://*.example.com` or `http://localhost:*`.
pub trusted_origins: Vec<String>,
/// Paths reported as disabled by `is_path_disabled` (exact string match).
pub disabled_paths: Vec<String>,
/// Session lifetime and session-cookie settings.
pub session: SessionConfig,
/// JWT settings.
pub jwt: JwtConfig,
/// Password policy and Argon2 parameters.
pub password: PasswordConfig,
/// Account and account-linking settings.
pub account: AccountConfig,
/// Optional email provider (default `None`).
pub email_provider: Option<Arc<dyn EmailProvider>>,
/// Less commonly tuned options (cookies, IP detection, CSRF/origin checks, database).
pub advanced: AdvancedConfig,
}
/// Settings for user accounts.
#[derive(Debug, Clone)]
pub struct AccountConfig {
/// Default `true`. NOTE(review): presumably refreshes stored account data on each sign-in.
pub update_account_on_sign_in: bool,
/// Policy for linking additional accounts to a user.
pub account_linking: AccountLinkingConfig,
/// Default `false`. NOTE(review): presumably controls encryption of stored OAuth tokens — confirm.
pub encrypt_oauth_tokens: bool,
}
/// Account-linking policy.
#[derive(Debug, Clone)]
pub struct AccountLinkingConfig {
/// Whether account linking is enabled (default `true`).
pub enabled: bool,
/// Provider identifiers treated as trusted for linking (default empty).
pub trusted_providers: Vec<String>,
/// Default `false`. NOTE(review): presumably allows linking accounts whose emails differ.
pub allow_different_emails: bool,
/// Default `false`. NOTE(review): presumably allows unlinking a user's last account.
pub allow_unlinking_all: bool,
/// Default `false`. NOTE(review): presumably copies provider profile data onto the user on link.
pub update_user_info_on_link: bool,
}
/// Session lifetime and session-cookie settings.
#[derive(Debug, Clone)]
pub struct SessionConfig {
/// Session lifetime (default 7 days).
pub expires_in: Duration,
/// Default 1 day. NOTE(review): presumably the age after which a session's expiry is extended.
pub update_age: Option<Duration>,
/// Default `false`. NOTE(review): refresh enforcement is not in this file.
pub disable_session_refresh: bool,
/// Default `None`. NOTE(review): presumably the window in which a session counts as "fresh".
pub fresh_age: Option<Duration>,
/// Session cookie name (default `better-auth.session-token`).
pub cookie_name: String,
/// `Secure` flag for the session cookie (default `true`).
pub cookie_secure: bool,
/// `HttpOnly` flag for the session cookie (default `true`).
pub cookie_http_only: bool,
/// `SameSite` attribute for the session cookie (default `Lax`).
pub cookie_same_site: SameSite,
/// Optional cookie-cache settings (default `None`).
pub cookie_cache: Option<CookieCacheConfig>,
}
/// JSON Web Token settings.
#[derive(Debug, Clone)]
pub struct JwtConfig {
/// Token lifetime (default 24 hours).
pub expires_in: Duration,
/// Algorithm name (default `"HS256"`).
pub algorithm: String,
/// Optional issuer (default `None`). NOTE(review): presumably the `iss` claim.
pub issuer: Option<String>,
/// Optional audience (default `None`). NOTE(review): presumably the `aud` claim.
pub audience: Option<String>,
}
/// Password policy and hashing parameters.
#[derive(Debug, Clone)]
pub struct PasswordConfig {
/// Minimum password length (default 8). NOTE(review): unit (bytes vs chars) decided by the checker — confirm.
pub min_length: usize,
/// Require an uppercase letter (default `false`).
pub require_uppercase: bool,
/// Require a lowercase letter (default `false`).
pub require_lowercase: bool,
/// Require a digit (default `false`).
pub require_numbers: bool,
/// Require a special character (default `false`).
pub require_special: bool,
/// Argon2 cost parameters.
pub argon2_config: Argon2Config,
}
/// Argon2 cost parameters.
#[derive(Debug, Clone)]
pub struct Argon2Config {
/// Memory cost (default 4096). NOTE(review): presumably KiB, as in the `argon2` crate — confirm.
pub memory_cost: u32,
/// Number of iterations (default 3).
pub time_cost: u32,
/// Degree of parallelism (default 1).
pub parallelism: u32,
}
/// Cookie `SameSite` attribute value; `Display` writes the variant name.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum SameSite {
/// Written as `Strict`.
Strict,
/// Written as `Lax`.
Lax,
/// Written as `None`.
None,
}
/// Settings for caching session data in a cookie.
#[derive(Debug, Clone)]
pub struct CookieCacheConfig {
/// Whether the cache is active (default `false`).
pub enabled: bool,
/// Lifetime of the cached data (default 5 minutes).
pub max_age: Duration,
/// Encoding strategy (default `Compact`).
pub strategy: CookieCacheStrategy,
}
/// Strategy for encoding the cookie cache.
///
/// NOTE(review): variant semantics live outside this file; the names suggest a
/// compact signed form, a JWT, and an encrypted JWE respectively.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum CookieCacheStrategy {
Compact,
Jwt,
Jwe,
}
impl Default for CookieCacheConfig {
fn default() -> Self {
Self {
enabled: false,
max_age: Duration::minutes(5),
strategy: CookieCacheStrategy::Compact,
}
}
}
impl Default for AccountConfig {
fn default() -> Self {
Self {
update_account_on_sign_in: true,
account_linking: AccountLinkingConfig::default(),
encrypt_oauth_tokens: false,
}
}
}
impl Default for AccountLinkingConfig {
fn default() -> Self {
Self {
enabled: true,
trusted_providers: Vec::new(),
allow_different_emails: false,
allow_unlinking_all: false,
update_user_info_on_link: false,
}
}
}
impl std::fmt::Display for SameSite {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
SameSite::Strict => f.write_str("Strict"),
SameSite::Lax => f.write_str("Lax"),
SameSite::None => f.write_str("None"),
}
}
}
/// Less commonly tuned options; every field takes its own `Default`.
#[derive(Debug, Clone, Default)]
pub struct AdvancedConfig {
/// Client-IP detection settings.
pub ip_address: IpAddressConfig,
/// Default `false`. NOTE(review): CSRF enforcement is not in this file.
pub disable_csrf_check: bool,
/// Default `false`. When `true`, `is_redirect_target_trusted` accepts any
/// http(s) URL instead of only trusted origins (smuggling guards still apply).
pub disable_origin_check: bool,
/// Cross-subdomain cookie domain, if any (default `None`).
pub cross_sub_domain_cookies: Option<CrossSubDomainConfig>,
/// Per-cookie overrides. NOTE(review): presumably keyed by logical cookie name — confirm.
pub cookies: HashMap<String, CookieOverride>,
/// Attributes applied to cookies unless overridden (NOTE(review): precedence not in this file).
pub default_cookie_attributes: CookieAttributes,
/// Optional cookie-name prefix (default `None`).
pub cookie_prefix: Option<String>,
/// Database-layer settings.
pub database: AdvancedDatabaseConfig,
/// Default empty. NOTE(review): presumably headers honoured only from trusted proxies — confirm.
pub trusted_proxy_headers: Vec<String>,
}
/// Settings for determining the client IP address.
#[derive(Debug, Clone)]
pub struct IpAddressConfig {
/// Candidate header names (default `x-forwarded-for`, `x-real-ip`).
pub headers: Vec<String>,
/// Default `false`; `true` disables IP tracking.
pub disable_ip_tracking: bool,
}
/// Cross-subdomain cookie settings (set via `AuthConfig::cross_sub_domain_cookies`).
#[derive(Debug, Clone)]
pub struct CrossSubDomainConfig {
/// The domain string exactly as supplied by the caller.
pub domain: String,
}
/// Optional cookie attributes; a `None` field means "not specified here".
#[derive(Debug, Clone, Default)]
pub struct CookieAttributes {
/// `Secure` flag.
pub secure: Option<bool>,
/// `HttpOnly` flag.
pub http_only: Option<bool>,
/// `SameSite` attribute.
pub same_site: Option<SameSite>,
/// `Path` attribute.
pub path: Option<String>,
/// `Max-Age` attribute. NOTE(review): presumably seconds — confirm.
pub max_age: Option<i64>,
/// `Domain` attribute.
pub domain: Option<String>,
}
/// Override for a single cookie's name and attributes.
#[derive(Debug, Clone, Default)]
pub struct CookieOverride {
/// Replacement cookie name, if any.
pub name: Option<String>,
/// Attribute overrides for this cookie.
pub attributes: CookieAttributes,
}
/// Database-layer settings.
#[derive(Debug, Clone)]
pub struct AdvancedDatabaseConfig {
/// Default result limit for multi-row queries (default 100).
pub default_find_many_limit: usize,
/// Default `false`. NOTE(review): presumably switches to numeric IDs — confirm.
pub use_number_id: bool,
}
impl Default for AuthConfig {
fn default() -> Self {
Self {
secret: String::new(),
app_name: "Better Auth".to_string(),
base_url: "http://localhost:3000".to_string(),
base_path: "/api/auth".to_string(),
trusted_origins: Vec::new(),
disabled_paths: Vec::new(),
session: SessionConfig::default(),
jwt: JwtConfig::default(),
password: PasswordConfig::default(),
account: AccountConfig::default(),
email_provider: None,
advanced: AdvancedConfig::default(),
}
}
}
impl Default for SessionConfig {
fn default() -> Self {
Self {
expires_in: Duration::hours(24 * 7), update_age: Some(Duration::hours(24)), disable_session_refresh: false,
fresh_age: None,
cookie_name: "better-auth.session-token".to_string(),
cookie_secure: true,
cookie_http_only: true,
cookie_same_site: SameSite::Lax,
cookie_cache: None,
}
}
}
impl Default for IpAddressConfig {
fn default() -> Self {
Self {
headers: vec!["x-forwarded-for".to_string(), "x-real-ip".to_string()],
disable_ip_tracking: false,
}
}
}
impl Default for AdvancedDatabaseConfig {
fn default() -> Self {
Self {
default_find_many_limit: 100,
use_number_id: false,
}
}
}
impl Default for JwtConfig {
fn default() -> Self {
Self {
expires_in: Duration::hours(24), algorithm: "HS256".to_string(),
issuer: None,
audience: None,
}
}
}
impl Default for PasswordConfig {
fn default() -> Self {
Self {
min_length: 8,
require_uppercase: false,
require_lowercase: false,
require_numbers: false,
require_special: false,
argon2_config: Argon2Config::default(),
}
}
}
impl Default for Argon2Config {
fn default() -> Self {
Self {
memory_cost: 4096, time_cost: 3, parallelism: 1, }
}
}
impl AuthConfig {
    /// Creates a configuration with the given secret; every other field comes
    /// from `AuthConfig::default()`.
    pub fn new(secret: impl Into<String>) -> Self {
        Self {
            secret: secret.into(),
            ..Default::default()
        }
    }

    /// Sets `app_name`.
    pub fn app_name(mut self, name: impl Into<String>) -> Self {
        self.app_name = name.into();
        self
    }

    /// Sets `base_url`, whose origin is always trusted by `is_origin_trusted`.
    pub fn base_url(mut self, url: impl Into<String>) -> Self {
        self.base_url = url.into();
        self
    }

    /// Replaces the whole account section.
    pub fn account(mut self, account: AccountConfig) -> Self {
        self.account = account;
        self
    }

    /// Sets `base_path`.
    pub fn base_path(mut self, path: impl Into<String>) -> Self {
        self.base_path = path.into();
        self
    }

    /// Appends one trusted-origin entry (a literal origin or a `*` glob pattern).
    pub fn trusted_origin(mut self, origin: impl Into<String>) -> Self {
        self.trusted_origins.push(origin.into());
        self
    }

    /// Replaces the full list of trusted-origin entries.
    pub fn trusted_origins(mut self, origins: Vec<String>) -> Self {
        self.trusted_origins = origins;
        self
    }

    /// Appends one path to `disabled_paths`.
    pub fn disabled_path(mut self, path: impl Into<String>) -> Self {
        self.disabled_paths.push(path.into());
        self
    }

    /// Replaces `disabled_paths`.
    pub fn disabled_paths(mut self, paths: Vec<String>) -> Self {
        self.disabled_paths = paths;
        self
    }

    /// Sets `session.expires_in`.
    pub fn session_expires_in(mut self, duration: Duration) -> Self {
        self.session.expires_in = duration;
        self
    }

    /// Sets `session.update_age`.
    pub fn session_update_age(mut self, duration: Duration) -> Self {
        self.session.update_age = Some(duration);
        self
    }

    /// Sets `session.disable_session_refresh`.
    pub fn disable_session_refresh(mut self, disabled: bool) -> Self {
        self.session.disable_session_refresh = disabled;
        self
    }

    /// Sets `session.fresh_age`.
    pub fn session_fresh_age(mut self, duration: Duration) -> Self {
        self.session.fresh_age = Some(duration);
        self
    }

    /// Sets `session.cookie_cache`.
    pub fn session_cookie_cache(mut self, config: CookieCacheConfig) -> Self {
        self.session.cookie_cache = Some(config);
        self
    }

    /// Sets `jwt.expires_in`.
    pub fn jwt_expires_in(mut self, duration: Duration) -> Self {
        self.jwt.expires_in = duration;
        self
    }

    /// Sets `password.min_length`.
    pub fn password_min_length(mut self, length: usize) -> Self {
        self.password.min_length = length;
        self
    }

    /// Replaces the whole advanced section.
    ///
    /// This discards any cookie prefix, CSRF flag, or cross-subdomain setting
    /// applied by earlier builder calls, so call it first.
    pub fn advanced(mut self, advanced: AdvancedConfig) -> Self {
        self.advanced = advanced;
        self
    }

    /// Sets `advanced.cookie_prefix`.
    pub fn cookie_prefix(mut self, prefix: impl Into<String>) -> Self {
        self.advanced.cookie_prefix = Some(prefix.into());
        self
    }

    /// Sets `advanced.disable_csrf_check`.
    pub fn disable_csrf_check(mut self, disabled: bool) -> Self {
        self.advanced.disable_csrf_check = disabled;
        self
    }

    /// Sets `advanced.cross_sub_domain_cookies` to the given domain.
    pub fn cross_sub_domain_cookies(mut self, domain: impl Into<String>) -> Self {
        self.advanced.cross_sub_domain_cookies = Some(CrossSubDomainConfig {
            domain: domain.into(),
        });
        self
    }

    /// Returns `true` when `origin` is the origin of `base_url` or matches an
    /// entry of `trusted_origins`.
    ///
    /// `origin` is compared verbatim. Callers should pass a serialized origin
    /// (`scheme://host[:port]`), such as the output of `extract_origin`.
    /// Each entry is canonicalized first. An entry matches when it equals
    /// `origin` exactly or when it glob-matches `origin`. Entries that cannot
    /// be canonicalized (empty result, e.g. missing `://`) never match.
    pub fn is_origin_trusted(&self, origin: &str) -> bool {
        if let Some(base_origin) = extract_origin(&self.base_url)
            && origin == base_origin
        {
            return true;
        }
        self.trusted_origins.iter().any(|pattern| {
            let pattern_origin = extract_pattern_origin(pattern);
            // An unusable pattern canonicalizes to "", and glob-matching ""
            // against an empty origin would succeed — never trust that.
            if pattern_origin.is_empty() {
                return false;
            }
            // Exact comparison first: literal origins such as IPv6 hosts
            // (`http://[::1]:3000`) contain `[`/`]`, which the glob engine
            // reads as a character class and so would never match literally.
            pattern_origin == origin || glob_match::glob_match(&pattern_origin, origin)
        })
    }

    /// Returns `true` when `path` is listed verbatim in `disabled_paths`.
    pub fn is_path_disabled(&self, path: &str) -> bool {
        self.disabled_paths.iter().any(|disabled| disabled == path)
    }

    /// Decides whether `target` may be used as a redirect destination.
    ///
    /// These targets are always rejected:
    /// - any control character or `"` (header/markup injection);
    /// - leading shapes a browser may read as another host: `//host`, `\host`,
    ///   `/\host`, and percent-encoded `/` or `\`.
    ///
    /// Otherwise a target starting with `/` is accepted. An absolute http(s)
    /// URL is accepted when its origin passes `is_origin_trusted`, or
    /// unconditionally when `advanced.disable_origin_check` is set. Anything
    /// else (other schemes, unparseable input) is rejected.
    pub fn is_redirect_target_trusted(&self, target: &str) -> bool {
        if target.chars().any(|c| c.is_control() || c == '"') {
            return false;
        }
        if is_authority_smuggling(target) {
            return false;
        }
        if self.advanced.disable_origin_check {
            return target.starts_with('/') || extract_origin(target).is_some();
        }
        if target.starts_with('/') {
            return true;
        }
        match extract_origin(target) {
            Some(origin) => self.is_origin_trusted(&origin),
            None => false,
        }
    }

    /// Like `is_redirect_target_trusted`, but additionally requires an absolute
    /// http(s) URL; relative paths are rejected.
    pub fn is_absolute_trusted_callback_url(&self, target: &str) -> bool {
        if !self.is_redirect_target_trusted(target) {
            return false;
        }
        extract_origin(target).is_some()
    }

    /// Checks that the configuration is usable.
    ///
    /// # Errors
    ///
    /// Returns `AuthError::config` when `secret` is empty or shorter than 32.
    /// The length is `str::len`, i.e. UTF-8 bytes rather than characters.
    pub fn validate(&self) -> Result<(), AuthError> {
        if self.secret.is_empty() {
            return Err(AuthError::config("Secret key cannot be empty"));
        }
        if self.secret.len() < 32 {
            return Err(AuthError::config(
                "Secret key must be at least 32 characters",
            ));
        }
        Ok(())
    }
}
/// Default maximum body size in bytes: 1 MiB (2^20 = 1,048,576).
/// NOTE(review): not referenced in this module; enforcement lives elsewhere.
pub const DEFAULT_MAX_BODY_BYTES: usize = 1 << 20;
pub fn extract_origin(url: &str) -> Option<String> {
let parsed = ::url::Url::parse(url).ok()?;
if !matches!(parsed.scheme(), "http" | "https") {
return None;
}
match parsed.origin() {
::url::Origin::Tuple(..) => Some(parsed.origin().ascii_serialization()),
::url::Origin::Opaque(_) => None,
}
}
/// Canonicalizes one `trusted_origins` entry into the string that request
/// origins are compared against.
///
/// * A pattern without `*` that `extract_origin` accepts is returned in that
///   canonical form.
/// * Otherwise the pattern is normalized by hand:
///   - the scheme is lowercased;
///   - everything from the first `/` after `://` is discarded;
///   - ASCII labels and labels containing `*` are lowercased;
///   - other labels are converted to punycode with IDNA;
///   - the port is dropped when it is the scheme default (`http` `:80`,
///     `https` `:443`) or empty (`host:`). This matches how the URL parser
///     treats literal origins.
/// * A pattern with no `://` yields an empty string; callers should treat that
///   as matching nothing.
fn extract_pattern_origin(pattern: &str) -> String {
    if !pattern.contains('*')
        && let Some(canonical) = extract_origin(pattern)
    {
        return canonical;
    }
    let Some(scheme_end) = pattern.find("://") else {
        return String::new();
    };
    let scheme = pattern[..scheme_end].to_ascii_lowercase();
    let rest = &pattern[scheme_end + 3..];
    let host_end = rest.find('/').unwrap_or(rest.len());
    let authority = &rest[..host_end];
    // Split off a trailing `:port` only when it is all digits / `*`, so the
    // colons inside a bracketed IPv6 literal (`[::1]`) are left alone.
    let (host, port_suffix) = match authority.rfind(':') {
        // `host:`: an empty port means "default port" in the URL standard,
        // so drop the dangling colon instead of keeping an unmatchable `:`.
        Some(idx) if idx + 1 == authority.len() => (&authority[..idx], ""),
        Some(idx)
            if authority[idx + 1..]
                .chars()
                .all(|c| c.is_ascii_digit() || c == '*') =>
        {
            (&authority[..idx], &authority[idx..])
        }
        _ => (authority, ""),
    };
    let canonical_host: String = host
        .split('.')
        .map(|label| {
            if label.contains('*') || label.is_ascii() {
                label.to_ascii_lowercase()
            } else {
                idna::domain_to_ascii(label).unwrap_or_else(|_| label.to_ascii_lowercase())
            }
        })
        .collect::<Vec<_>>()
        .join(".");
    // Origins never carry their scheme's default port, so strip it here too.
    let port_suffix = match (scheme.as_str(), port_suffix) {
        ("http", ":80") | ("https", ":443") => "",
        _ => port_suffix,
    };
    format!("{}://{}{}", scheme, canonical_host, port_suffix)
}
/// Reports whether `target`, after leading whitespace is ignored, starts with
/// a shape a browser may resolve as a different host.
///
/// The rejected shapes are `\...`, `//...`, `/\...`, and a percent-encoded
/// `/` or `\` (`%2F`, `%5C`, either case) at the start or right after one
/// leading `/`.
fn is_authority_smuggling(target: &str) -> bool {
    let candidate = target.trim_start_matches(char::is_whitespace);
    match candidate.chars().next() {
        Some('\\') => return true,
        Some('/') => {
            let after_slash = &candidate[1..];
            if after_slash.starts_with('/') || after_slash.starts_with('\\') {
                return true;
            }
        }
        _ => {}
    }
    // Percent-encoded separators may appear first or after a single `/`.
    let encoded = candidate.strip_prefix('/').unwrap_or(candidate);
    encoded
        .get(..3)
        .is_some_and(|head| head.eq_ignore_ascii_case("%2f") || head.eq_ignore_ascii_case("%5c"))
}
// Unit tests for trusted-origin matching and redirect-target validation.
#[cfg(test)]
mod tests {
use super::*;
// Config with base URL https://app.example.com that trusts only `trusted`.
fn config_with(trusted: Vec<&str>) -> AuthConfig {
AuthConfig {
base_url: "https://app.example.com".into(),
trusted_origins: trusted.into_iter().map(String::from).collect(),
..AuthConfig::default()
}
}
#[test]
fn redirect_target_allows_relative_path() {
let cfg = config_with(vec![]);
assert!(cfg.is_redirect_target_trusted("/dashboard"));
assert!(cfg.is_redirect_target_trusted("/reset-password?token=abc"));
}
#[test]
fn redirect_target_rejects_protocol_relative() {
let cfg = config_with(vec![]);
assert!(!cfg.is_redirect_target_trusted("//evil.com/x"));
assert!(!cfg.is_redirect_target_trusted("//evil.com"));
}
#[test]
fn redirect_target_allows_base_url_origin() {
let cfg = config_with(vec![]);
assert!(cfg.is_redirect_target_trusted("https://app.example.com/dashboard"));
}
#[test]
fn redirect_target_allows_trusted_origin() {
let cfg = config_with(vec!["https://admin.example.com"]);
assert!(cfg.is_redirect_target_trusted("https://admin.example.com/callback"));
}
#[test]
fn redirect_target_rejects_untrusted_origin() {
let cfg = config_with(vec!["https://admin.example.com"]);
assert!(!cfg.is_redirect_target_trusted("https://evil.com/cb"));
}
#[test]
fn redirect_target_rejects_unparseable_absolute() {
let cfg = config_with(vec![]);
assert!(!cfg.is_redirect_target_trusted("javascript:alert(1)"));
assert!(!cfg.is_redirect_target_trusted("data:text/html,x"));
}
// disable_origin_check lifts the origin allow-list but must keep the smuggling guards.
#[test]
fn redirect_target_bypass_does_not_cover_authority_smuggling() {
let mut cfg = config_with(vec![]);
cfg.advanced.disable_origin_check = true;
assert!(cfg.is_redirect_target_trusted("https://evil.com/cb"));
assert!(cfg.is_redirect_target_trusted("/dashboard"));
assert!(!cfg.is_redirect_target_trusted("//evil.com"));
assert!(!cfg.is_redirect_target_trusted("/\\evil.com"));
assert!(!cfg.is_redirect_target_trusted("\\\\evil.com"));
}
#[test]
fn redirect_target_rejects_backslash_authority_bypass() {
let cfg = config_with(vec![]);
assert!(!cfg.is_redirect_target_trusted("/\\evil.com"));
assert!(!cfg.is_redirect_target_trusted("/\\\\evil.com"));
assert!(!cfg.is_redirect_target_trusted("\\evil.com"));
assert!(!cfg.is_redirect_target_trusted("\\\\evil.com"));
assert!(!cfg.is_redirect_target_trusted("\\/evil.com"));
assert!(!cfg.is_redirect_target_trusted(" //evil.com"));
assert!(!cfg.is_redirect_target_trusted("\t/\\evil.com"));
}
// "app.example.com@" is userinfo; the real host of this URL is evil.com.
#[test]
fn redirect_target_strips_userinfo_when_comparing_origin() {
let cfg = config_with(vec![]);
assert!(!cfg.is_redirect_target_trusted("https://app.example.com@evil.com/x"));
}
#[test]
fn redirect_target_allows_same_origin_with_query_and_fragment() {
let cfg = config_with(vec![]);
assert!(cfg.is_redirect_target_trusted("https://app.example.com?retry=1"));
assert!(cfg.is_redirect_target_trusted("https://app.example.com#/route"));
assert!(cfg.is_redirect_target_trusted("https://app.example.com/path?x=1#y"));
}
#[test]
fn redirect_target_rejects_non_http_schemes() {
let cfg = config_with(vec![]);
assert!(!cfg.is_redirect_target_trusted("javascript:alert(1)"));
assert!(!cfg.is_redirect_target_trusted("data:text/html,x"));
assert!(!cfg.is_redirect_target_trusted("file:///etc/passwd"));
assert!(!cfg.is_redirect_target_trusted("ftp://example.com/"));
}
#[test]
fn redirect_target_preserves_non_default_port_in_origin_match() {
let cfg = config_with(vec!["https://admin.example.com:8443"]);
assert!(cfg.is_redirect_target_trusted("https://admin.example.com:8443/x"));
assert!(!cfg.is_redirect_target_trusted("https://admin.example.com/x"));
}
// CR/LF, NUL and quotes would allow header or markup injection via Location.
#[test]
fn redirect_target_rejects_control_chars_and_quotes() {
let cfg = config_with(vec![]);
assert!(!cfg.is_redirect_target_trusted("/path\r\nEvil-Header: x"));
assert!(!cfg.is_redirect_target_trusted("/path\nEvil: x"));
assert!(!cfg.is_redirect_target_trusted("/path\"><script>"));
assert!(!cfg.is_redirect_target_trusted("/path\u{0000}null"));
assert!(!cfg.is_redirect_target_trusted("https://app.example.com/x\r\n"));
}
// :443 / :80 in a pattern are the scheme defaults and must canonicalize away.
#[test]
fn trusted_origins_with_explicit_default_ports_still_match() {
let cfg = config_with(vec![
"https://admin.example.com:443",
"http://legacy.example.com:80",
]);
assert!(cfg.is_origin_trusted("https://admin.example.com"));
assert!(cfg.is_origin_trusted("http://legacy.example.com"));
assert!(cfg.is_redirect_target_trusted("https://admin.example.com/cb"));
assert!(cfg.is_redirect_target_trusted("http://legacy.example.com/cb"));
}
#[test]
fn redirect_target_bypass_still_rejects_dangerous_schemes() {
let mut cfg = config_with(vec![]);
cfg.advanced.disable_origin_check = true;
assert!(!cfg.is_redirect_target_trusted("javascript:alert(1)"));
assert!(!cfg.is_redirect_target_trusted("data:text/html,<script>x</script>"));
assert!(!cfg.is_redirect_target_trusted("file:///etc/passwd"));
assert!(!cfg.is_redirect_target_trusted("ftp://example.com/"));
assert!(cfg.is_redirect_target_trusted("/dashboard"));
assert!(cfg.is_redirect_target_trusted("https://evil.com/cb"));
}
// Wildcard patterns with non-ASCII labels are compared in punycode form.
#[test]
fn trusted_origins_wildcard_idn_matches_punycode_callback() {
let cfg = config_with(vec!["https://*.bücher.example"]);
assert!(cfg.is_origin_trusted("https://shop.xn--bcher-kva.example"));
assert!(cfg.is_redirect_target_trusted("https://shop.xn--bcher-kva.example/path"));
}
#[test]
fn trusted_origins_pattern_lowercases_scheme_and_host() {
let cfg = config_with(vec!["HTTPS://APP.Example.COM"]);
assert!(cfg.is_origin_trusted("https://app.example.com"));
assert!(cfg.is_redirect_target_trusted("https://app.example.com/x"));
}
#[test]
fn trusted_origins_punycode_idn_matches_punycode_callback() {
let cfg = config_with(vec!["https://bücher.example"]);
assert!(cfg.is_origin_trusted("https://xn--bcher-kva.example"));
assert!(cfg.is_redirect_target_trusted("https://xn--bcher-kva.example/book"));
}
// Callback URLs must be absolute, unlike general redirect targets.
#[test]
fn absolute_trusted_callback_url_rejects_relative_paths() {
let cfg = config_with(vec!["https://admin.example.com"]);
assert!(cfg.is_redirect_target_trusted("/dashboard"));
assert!(!cfg.is_absolute_trusted_callback_url("/dashboard"));
assert!(!cfg.is_absolute_trusted_callback_url("/reset?token=x"));
assert!(cfg.is_absolute_trusted_callback_url("https://admin.example.com/cb"));
assert!(!cfg.is_absolute_trusted_callback_url("https://evil.com/cb"));
assert!(!cfg.is_absolute_trusted_callback_url("javascript:alert(1)"));
}
// %2F / %5C are encoded '/' and '\'; reject them at the start of a target.
#[test]
fn redirect_target_rejects_percent_encoded_authority_bypass() {
let cfg = config_with(vec![]);
assert!(!cfg.is_redirect_target_trusted("/%2Fevil.com"));
assert!(!cfg.is_redirect_target_trusted("/%2fevil.com"));
assert!(!cfg.is_redirect_target_trusted("/%5Cevil.com"));
assert!(!cfg.is_redirect_target_trusted("/%5cevil.com"));
assert!(!cfg.is_redirect_target_trusted("%2Fevil.com"));
}
#[test]
fn redirect_target_rejects_bare_backslash_under_disable_origin_check() {
let mut cfg = config_with(vec![]);
cfg.advanced.disable_origin_check = true;
assert!(!cfg.is_redirect_target_trusted("\\evil.com"));
assert!(!cfg.is_redirect_target_trusted(" \\evil.com"));
}
// `*` may stand in for the port, a subdomain, or the scheme.
#[test]
fn trusted_origins_supports_port_and_scheme_globs() {
let cfg = config_with(vec![
"http://localhost:*",
"https://*.example.com",
"*://api.staging.test",
]);
assert!(cfg.is_origin_trusted("http://localhost:3000"));
assert!(cfg.is_origin_trusted("http://localhost:8080"));
assert!(cfg.is_origin_trusted("https://app.example.com"));
assert!(cfg.is_origin_trusted("http://api.staging.test"));
assert!(cfg.is_origin_trusted("https://api.staging.test"));
assert!(!cfg.is_origin_trusted("http://localhost.evil.com"));
}
}