use std::collections::HashSet;
use time::Duration;
use crate::error::{AuthError, OAuthError};
/// Top-level authentication configuration.
///
/// `Debug` is implemented by hand so that `secret` is never written to logs or panic
/// messages; every other field is printed through its own `Debug` impl.
///
/// Field descriptions below follow the field names and the values set by `Default`; the
/// code that consumes them lives outside this file.
#[derive(Clone)]
pub struct AuthConfig {
    /// Server-side secret. Empty in `Default`; presumably must be set to a strong random
    /// value before use — confirm where it is consumed. Redacted in `Debug` output.
    pub secret: String,
    /// Session lifetime. Default: 30 days.
    pub session_ttl: Duration,
    /// Lifetime used for verification tokens. Default: 1 hour.
    pub verification_ttl: Duration,
    /// Lifetime used for reset tokens. Default: 1 hour.
    pub reset_ttl: Duration,
    /// Length of generated tokens (unit decided by the token generator — confirm).
    /// Default: 32.
    pub token_length: usize,
    /// Email verification / auto sign-in switches.
    pub email: EmailConfig,
    /// Session cookie settings.
    pub cookie: CookieConfig,
    /// OAuth provider settings.
    pub oauth: OAuthConfig,
}

impl std::fmt::Debug for AuthConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Exhaustive destructuring: adding a field to `AuthConfig` makes this fail to
        // compile instead of silently dropping the new field from the debug output.
        let Self {
            secret: _,
            session_ttl,
            verification_ttl,
            reset_ttl,
            token_length,
            email,
            cookie,
            oauth,
        } = self;
        f.debug_struct("AuthConfig")
            // Never print the secret itself.
            .field("secret", &"[REDACTED]")
            .field("session_ttl", session_ttl)
            .field("verification_ttl", verification_ttl)
            .field("reset_ttl", reset_ttl)
            .field("token_length", token_length)
            .field("email", email)
            .field("cookie", cookie)
            .field("oauth", oauth)
            .finish()
    }
}
/// Email verification and automatic sign-in switches.
///
/// NOTE(review): the code that reads these flags is not in this file; the per-field
/// descriptions follow the field names and should be confirmed against the
/// signup/login/verification handlers. Defaults come from `EmailConfig::default`.
#[derive(Debug, Clone)]
pub struct EmailConfig {
    /// Send a verification email when a user signs up. Default: `true`.
    pub send_verification_on_signup: bool,
    /// Refuse login until the email address is verified. Default: `false`.
    pub require_verification_to_login: bool,
    /// Sign the user in immediately after signup. Default: `true`.
    pub auto_sign_in_after_signup: bool,
    /// Sign the user in once their email is verified. Default: `false`.
    pub auto_sign_in_after_verification: bool,
}
/// Settings for the session cookie (default name `rs_auth_session`).
///
/// NOTE(review): how these fields are turned into `Set-Cookie` attributes is decided by
/// code outside this file; the descriptions below follow the field names. Defaults come
/// from `CookieConfig::default`.
#[derive(Debug, Clone)]
pub struct CookieConfig {
    /// Cookie name. Default: `"rs_auth_session"`.
    pub name: String,
    /// `HttpOnly` attribute. Default: `true`.
    pub http_only: bool,
    /// `Secure` attribute. Default: `true`.
    pub secure: bool,
    /// `SameSite` attribute. Default: `SameSite::Lax`.
    pub same_site: SameSite,
    /// `Path` attribute. Default: `"/"`.
    pub path: String,
    /// `Domain` attribute. Default: `None` (presumably the attribute is then omitted —
    /// confirm against the cookie builder).
    pub domain: Option<String>,
}
/// Cookie `SameSite` policy, used by `CookieConfig::same_site` (default `Lax`).
///
/// Variants are named after the `SameSite` attribute values `Strict`, `Lax` and `None`.
///
/// NOTE(review): the `SameSite::None` variant is unrelated to `Option::None`; always refer
/// to it by its qualified path. Nothing in this file checks that `SameSite::None` is paired
/// with `CookieConfig::secure = true` — confirm whether the cookie-writing code enforces it.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SameSite {
    Strict,
    Lax,
    None,
}
/// OAuth sign-in settings.
///
/// Call `OAuthConfig::validate` to check the `providers` list before use. Defaults come
/// from `OAuthConfig::default`.
#[derive(Debug, Clone)]
pub struct OAuthConfig {
    /// Configured providers. `validate` requires unique, supported `provider_id`s.
    /// Default: empty.
    pub providers: Vec<OAuthProviderEntry>,
    /// Whether an OAuth identity may be linked to an existing account without an explicit
    /// linking step (matching rule decided outside this file — confirm). Default: `true`.
    pub allow_implicit_account_linking: bool,
    /// Redirect target after a successful OAuth flow. Default: `None`. Not checked by
    /// `validate`.
    pub success_redirect: Option<String>,
    /// Redirect target after a failed OAuth flow. Default: `None`. Not checked by
    /// `validate`.
    pub error_redirect: Option<String>,
}
/// Client registration and endpoint settings for one OAuth provider.
///
/// `Debug` is implemented by hand so that `client_secret` is redacted; all other fields,
/// including `client_id`, are printed as-is.
#[derive(Clone)]
pub struct OAuthProviderEntry {
    /// Provider identifier. `OAuthConfig::validate` accepts only `"google"` and `"github"`
    /// and rejects duplicates.
    pub provider_id: String,
    /// OAuth client ID. Must not be blank.
    pub client_id: String,
    /// OAuth client secret. Must not be blank. Redacted in `Debug` output.
    pub client_secret: String,
    /// Callback URL for this provider. Must be a valid URL.
    pub redirect_url: String,
    /// Optional authorization endpoint; validated as a URL when set. When `None`,
    /// presumably a built-in default for the provider is used — confirm.
    pub auth_url: Option<String>,
    /// Optional token endpoint; validated as a URL when set.
    pub token_url: Option<String>,
    /// Optional user-info endpoint; validated as a URL when set.
    pub userinfo_url: Option<String>,
}

impl std::fmt::Debug for OAuthProviderEntry {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Exhaustive destructuring: a new field forces this impl to be revisited rather
        // than being silently left out of the debug output.
        let Self {
            provider_id,
            client_id,
            client_secret: _,
            redirect_url,
            auth_url,
            token_url,
            userinfo_url,
        } = self;
        f.debug_struct("OAuthProviderEntry")
            .field("provider_id", provider_id)
            .field("client_id", client_id)
            // Never print the client secret itself.
            .field("client_secret", &"[REDACTED]")
            .field("redirect_url", redirect_url)
            .field("auth_url", auth_url)
            .field("token_url", token_url)
            .field("userinfo_url", userinfo_url)
            .finish()
    }
}
impl Default for OAuthConfig {
fn default() -> Self {
Self {
providers: vec![],
allow_implicit_account_linking: true,
success_redirect: None,
error_redirect: None,
}
}
}
impl OAuthConfig {
    /// Checks every entry in `providers` and reports the first configuration problem.
    ///
    /// Checks run per entry, in this order:
    /// 1. `provider_id` must not repeat an earlier entry's id;
    /// 2. `provider_id` must be `"google"` or `"github"`;
    /// 3. `client_id` and `client_secret` must not be blank (whitespace-only is blank);
    /// 4. `redirect_url` must pass `validate_url`, as must each of `auth_url`,
    ///    `token_url` and `userinfo_url` that is set.
    ///
    /// `allow_implicit_account_linking`, `success_redirect` and `error_redirect` are not
    /// checked here.
    ///
    /// # Errors
    ///
    /// Returns `AuthError::OAuth(OAuthError::UnsupportedProvider { .. })` for an unknown
    /// provider id. Every other problem is reported as
    /// `AuthError::OAuth(OAuthError::Misconfigured { .. })`, whose message names the
    /// provider and the offending field.
    pub fn validate(&self) -> Result<(), AuthError> {
        let misconfigured =
            |message: String| AuthError::OAuth(OAuthError::Misconfigured { message });

        let mut seen_provider_ids = HashSet::new();
        for provider in &self.providers {
            let id = provider.provider_id.as_str();

            if !seen_provider_ids.insert(id) {
                return Err(misconfigured(format!("duplicate provider_id: {id}")));
            }
            if !matches!(id, "google" | "github") {
                return Err(AuthError::OAuth(OAuthError::UnsupportedProvider {
                    provider: provider.provider_id.clone(),
                }));
            }

            for (field, value) in [
                ("client_id", &provider.client_id),
                ("client_secret", &provider.client_secret),
            ] {
                if value.trim().is_empty() {
                    return Err(misconfigured(format!("provider {id} has empty {field}")));
                }
            }

            // `validate_url` already rejects a blank value with
            // "provider {id} has empty redirect_url", so no separate emptiness check is
            // needed for this field.
            validate_url("redirect_url", id, &provider.redirect_url)?;

            // Endpoint overrides are optional; only validate the ones that are set.
            for (field, value) in [
                ("auth_url", &provider.auth_url),
                ("token_url", &provider.token_url),
                ("userinfo_url", &provider.userinfo_url),
            ] {
                if let Some(url) = value.as_deref() {
                    validate_url(field, id, url)?;
                }
            }
        }
        Ok(())
    }
}
/// Checks that `value` is a usable URL for OAuth setting `field` of provider `provider_id`.
///
/// The value must be non-blank, parse with `reqwest::Url::parse`, and use the `http` or
/// `https` scheme.
///
/// # Errors
///
/// Returns `AuthError::OAuth(OAuthError::Misconfigured { .. })`. The message is
/// "provider {provider_id} has empty {field}" for a blank value, and
/// "provider {provider_id} has invalid {field}: ..." for a parse failure or a non-HTTP(S)
/// scheme.
fn validate_url(field: &str, provider_id: &str, value: &str) -> Result<(), AuthError> {
    let misconfigured =
        |message: String| AuthError::OAuth(OAuthError::Misconfigured { message });

    if value.trim().is_empty() {
        return Err(misconfigured(format!(
            "provider {provider_id} has empty {field}"
        )));
    }

    let url = reqwest::Url::parse(value).map_err(|e| {
        misconfigured(format!("provider {provider_id} has invalid {field}: {e}"))
    })?;

    // `Url::parse` accepts any syntactically valid scheme. A value with a missing
    // "http://" such as "localhost:3000/callback" parses with scheme "localhost", and
    // `mailto:`, `file:` or `javascript:` URLs parse too. None of these is a usable OAuth
    // endpoint or web callback, so anything other than http(s) is a misconfiguration.
    if !matches!(url.scheme(), "http" | "https") {
        return Err(misconfigured(format!(
            "provider {provider_id} has invalid {field}: unsupported scheme `{}`, expected http or https",
            url.scheme()
        )));
    }
    Ok(())
}
impl Default for AuthConfig {
fn default() -> Self {
Self {
secret: String::new(),
session_ttl: Duration::days(30),
verification_ttl: Duration::hours(1),
reset_ttl: Duration::hours(1),
token_length: 32,
email: EmailConfig::default(),
cookie: CookieConfig::default(),
oauth: OAuthConfig::default(),
}
}
}
impl Default for EmailConfig {
fn default() -> Self {
Self {
send_verification_on_signup: true,
require_verification_to_login: false,
auto_sign_in_after_signup: true,
auto_sign_in_after_verification: false,
}
}
}
impl Default for CookieConfig {
fn default() -> Self {
Self {
name: "rs_auth_session".to_string(),
http_only: true,
secure: true,
same_site: SameSite::Lax,
path: "/".to_string(),
domain: None,
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a provider entry that passes `OAuthConfig::validate` on its own; individual
    /// tests tweak one field to trigger one specific failure.
    fn provider(provider_id: &str, redirect_url: &str) -> OAuthProviderEntry {
        OAuthProviderEntry {
            provider_id: provider_id.to_string(),
            client_id: "client-id".to_string(),
            client_secret: "client-secret".to_string(),
            redirect_url: redirect_url.to_string(),
            auth_url: None,
            token_url: None,
            userinfo_url: None,
        }
    }

    /// Wraps `providers` in an otherwise-default `OAuthConfig`.
    fn with_providers(providers: Vec<OAuthProviderEntry>) -> OAuthConfig {
        OAuthConfig {
            providers,
            ..Default::default()
        }
    }

    /// True when `result` failed specifically with `OAuthError::Misconfigured`, so tests
    /// cannot pass because validation failed for some unrelated reason.
    fn is_misconfigured(result: &Result<(), AuthError>) -> bool {
        matches!(
            result,
            Err(AuthError::OAuth(OAuthError::Misconfigured { .. }))
        )
    }

    #[test]
    fn default_config_has_sane_values() {
        let config = AuthConfig::default();
        assert_eq!(
            config.session_ttl,
            Duration::days(30),
            "session_ttl should be 30 days"
        );
        assert_eq!(config.token_length, 32, "token_length should be 32");
        assert_eq!(
            config.cookie.name, "rs_auth_session",
            "cookie name should be 'rs_auth_session'"
        );
        assert_eq!(
            config.verification_ttl,
            Duration::hours(1),
            "verification_ttl should be 1 hour"
        );
        assert_eq!(
            config.reset_ttl,
            Duration::hours(1),
            "reset_ttl should be 1 hour"
        );
        assert!(config.cookie.http_only, "cookie should be http_only");
        assert!(config.cookie.secure, "cookie should be secure");
        assert_eq!(
            config.cookie.same_site,
            SameSite::Lax,
            "cookie same_site should be Lax"
        );
        assert_eq!(config.cookie.path, "/", "cookie path should be '/'");
        assert_eq!(config.cookie.domain, None, "cookie domain should be None");
        assert!(config.email.send_verification_on_signup);
        assert!(!config.email.require_verification_to_login);
        assert!(config.email.auto_sign_in_after_signup);
        assert!(!config.email.auto_sign_in_after_verification);
        assert!(config.oauth.providers.is_empty());
        assert!(config.oauth.allow_implicit_account_linking);
    }

    #[test]
    fn default_oauth_config_is_valid() {
        assert!(OAuthConfig::default().validate().is_ok());
    }

    #[test]
    fn oauth_config_accepts_valid_providers() {
        let config = with_providers(vec![
            provider("google", "https://example.com/callback/google"),
            provider("github", "https://example.com/callback/github"),
        ]);
        assert!(config.validate().is_ok());
    }

    #[test]
    fn oauth_config_rejects_duplicate_provider_ids() {
        let config = with_providers(vec![
            provider("google", "https://example.com/callback/google"),
            provider("google", "https://example.com/callback/google-2"),
        ]);
        assert!(is_misconfigured(&config.validate()));
    }

    #[test]
    fn oauth_config_rejects_unsupported_provider_ids() {
        let config = with_providers(vec![provider(
            "gitlab",
            "https://example.com/callback/gitlab",
        )]);
        assert!(matches!(
            config.validate(),
            Err(AuthError::OAuth(OAuthError::UnsupportedProvider { .. }))
        ));
    }

    #[test]
    fn oauth_config_rejects_blank_credentials() {
        let mut blank_id = provider("google", "https://example.com/callback/google");
        blank_id.client_id = String::new();
        assert!(is_misconfigured(&with_providers(vec![blank_id]).validate()));

        let mut blank_secret = provider("google", "https://example.com/callback/google");
        blank_secret.client_secret = "   ".to_string();
        assert!(is_misconfigured(
            &with_providers(vec![blank_secret]).validate()
        ));
    }

    #[test]
    fn oauth_config_rejects_invalid_urls() {
        let config = with_providers(vec![provider("google", "not-a-url")]);
        assert!(is_misconfigured(&config.validate()));

        let config = with_providers(vec![provider("google", "")]);
        assert!(is_misconfigured(&config.validate()));

        let mut bad_override = provider("google", "https://example.com/callback/google");
        bad_override.token_url = Some("not-a-url".to_string());
        assert!(is_misconfigured(
            &with_providers(vec![bad_override]).validate()
        ));
    }
}