#![allow(missing_copy_implementations)]
use std::path::PathBuf;
use std::time::Duration;
use chrono::{DateTime, FixedOffset, Utc};
use cot_core::error::impl_into_cot_error;
use derive_builder::Builder;
use derive_more::with_trait::{Debug, From};
use serde::{Deserialize, Serialize};
use subtle::ConstantTimeEq;
use thiserror::Error;
#[cfg(feature = "email")]
use crate::email::transport::smtp::Mechanism;
use crate::utils::chrono::DateTimeWithOffsetAdapter;
/// Top-level configuration of a project.
///
/// Every field may be omitted when deserializing (`#[serde(default)]`); missing
/// fields take the values produced by [`ProjectConfig::default`], which is the
/// same as `ProjectConfig::builder().build()`.
#[derive(Debug, Clone, PartialEq, Eq, Builder, Serialize, Deserialize)]
#[builder(build_fn(skip, error = std::convert::Infallible))]
#[serde(default)]
#[non_exhaustive]
pub struct ProjectConfig {
    /// Debug-mode flag. Defaults to `true` when the crate is built with debug
    /// assertions enabled and `false` otherwise.
    pub debug: bool,
    /// Whether a panic hook should be registered. Defaults to `true`.
    pub register_panic_hook: bool,
    /// The primary secret key. Defaults to an empty key.
    pub secret_key: SecretKey,
    /// Additional secret keys. Defaults to an empty list.
    // NOTE(review): presumably accepted alongside `secret_key` during key
    // rotation — confirm against the code that consumes them.
    pub fallback_secret_keys: Vec<SecretKey>,
    /// Authentication backend selection. Defaults to [`AuthBackendConfig::None`].
    pub auth_backend: AuthBackendConfig,
    /// Database settings (only with the `db` feature).
    #[cfg(feature = "db")]
    pub database: DatabaseConfig,
    /// Cache settings (only with the `cache` feature).
    #[cfg(feature = "cache")]
    pub cache: CacheConfig,
    /// Static file serving settings.
    pub static_files: StaticFilesConfig,
    /// Per-middleware settings.
    pub middlewares: MiddlewareConfig,
    /// Email settings (only with the `email` feature).
    #[cfg(feature = "email")]
    pub email: EmailConfig,
}
/// Default value of [`ProjectConfig::debug`]: `true` exactly when the crate is
/// compiled with debug assertions enabled.
const fn default_debug() -> bool {
    cfg!(debug_assertions)
}
impl Default for ProjectConfig {
fn default() -> Self {
ProjectConfig::builder().build()
}
}
impl ProjectConfig {
    /// Returns a fresh builder with no fields set.
    #[must_use]
    pub fn builder() -> ProjectConfigBuilder {
        Default::default()
    }

    /// Returns a configuration intended for local development.
    ///
    /// Debug mode and the panic hook are enabled; with the `db` feature the
    /// database URL is set to `sqlite::memory:`. Everything else uses the
    /// builder defaults.
    #[must_use]
    pub fn dev_default() -> ProjectConfig {
        let mut dev_builder = Self::builder();
        dev_builder.debug(true);
        dev_builder.register_panic_hook(true);
        #[cfg(feature = "db")]
        dev_builder.database(DatabaseConfig::builder().url("sqlite::memory:").build());
        dev_builder.build()
    }

    /// Parses a configuration from TOML text; omitted fields take defaults.
    ///
    /// # Errors
    ///
    /// Returns an error (wrapping the TOML deserialization error) if
    /// `toml_content` is not a valid configuration.
    pub fn from_toml(toml_content: &str) -> crate::Result<ProjectConfig> {
        let parsed = toml::from_str::<Self>(toml_content).map_err(ParseConfig)?;
        Ok(parsed)
    }
}
/// Error produced by [`ProjectConfig::from_toml`] when the TOML input cannot
/// be deserialized; wraps the underlying [`toml::de::Error`].
#[derive(Debug, Error)]
#[error("could not parse the config: {0}")]
struct ParseConfig(#[from] toml::de::Error);
// Lets `ParseConfig` be converted into the crate-level error type, which the
// `?` in `ProjectConfig::from_toml` relies on.
// NOTE(review): the exact conversion is generated by `impl_into_cot_error!`
// from `cot_core` and is not visible here.
impl_into_cot_error!(ParseConfig);
impl ProjectConfigBuilder {
#[must_use]
pub fn build(&self) -> ProjectConfig {
let debug = self.debug.unwrap_or(default_debug());
ProjectConfig {
debug,
register_panic_hook: self.register_panic_hook.unwrap_or(true),
secret_key: self.secret_key.clone().unwrap_or_default(),
fallback_secret_keys: self.fallback_secret_keys.clone().unwrap_or_default(),
auth_backend: self.auth_backend.unwrap_or_default(),
#[cfg(feature = "db")]
database: self.database.clone().unwrap_or_default(),
#[cfg(feature = "cache")]
cache: self.cache.clone().unwrap_or_default(),
static_files: self.static_files.clone().unwrap_or_default(),
middlewares: self.middlewares.clone().unwrap_or_default(),
#[cfg(feature = "email")]
email: self.email.clone().unwrap_or_default(),
}
}
}
/// Which authentication backend to use.
///
/// Serialized as a table tagged by `type`, e.g. `{ type = "none" }`.
#[derive(Debug, Default, Copy, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
#[non_exhaustive]
pub enum AuthBackendConfig {
    /// No authentication backend (the default).
    #[default]
    None,
    /// Database-backed authentication (only with the `db` feature).
    #[cfg(feature = "db")]
    Database,
}
/// Database settings.
#[cfg(feature = "db")]
#[derive(Debug, Default, Clone, PartialEq, Eq, Builder, Serialize, Deserialize)]
#[builder(build_fn(skip, error = std::convert::Infallible))]
#[serde(default)]
#[non_exhaustive]
pub struct DatabaseConfig {
    /// Connection URL; `None` when not configured (the `Default`).
    #[builder(setter(into, strip_option), default)]
    pub url: Option<DatabaseUrl>,
}
#[cfg(feature = "db")]
impl DatabaseConfigBuilder {
    /// Builds a [`DatabaseConfig`] from the values set on this builder.
    ///
    /// # Panics
    ///
    /// Panics with `"Database URL is required"` if `url` was never set on the
    /// builder, even though [`DatabaseConfig::url`] itself is optional.
    #[must_use]
    pub fn build(&self) -> DatabaseConfig {
        let url = self
            .url
            .as_ref()
            .expect("Database URL is required")
            .clone();
        DatabaseConfig { url }
    }
}
#[cfg(feature = "db")]
impl DatabaseConfig {
    /// Returns a fresh builder with no fields set.
    #[must_use]
    pub fn builder() -> DatabaseConfigBuilder {
        Default::default()
    }
}
/// An expiration policy: never, after a duration, or at a fixed instant.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
#[non_exhaustive]
pub enum Timeout {
    /// Never expires.
    Never,
    /// Expires once the given duration has elapsed since an insertion time
    /// supplied to [`Timeout::is_expired`].
    After(Duration),
    /// Expires at the given absolute instant.
    AtDateTime(DateTime<FixedOffset>),
}
impl Timeout {
#[must_use]
pub fn is_expired(&self, insertion_time: Option<DateTime<FixedOffset>>) -> bool {
match self {
Timeout::Never => false,
Timeout::After(dur) => {
if let Some(time) = insertion_time {
let expiry_time = time + chrono::Duration::from_std(*dur).unwrap_or_default();
let now_in_offset = Utc::now().with_timezone(time.offset());
return now_in_offset >= expiry_time;
}
panic!("insertion_time is required for Timeout::After expiry check");
}
Timeout::AtDateTime(dt) => {
let now_in_offset = Utc::now().with_timezone(dt.offset());
now_in_offset >= *dt
}
}
}
#[must_use]
#[expect(clippy::missing_panics_doc)]
pub fn canonicalize(self) -> Self {
match self {
Timeout::After(duration) => {
let time_now = Utc::now().with_timezone(&FixedOffset::east_opt(0).expect("conversion to FixedOffset(0) should not fail since 0 is a valid timezone offset"));
let expiry_time =
time_now + chrono::Duration::from_std(duration).unwrap_or_default();
Timeout::AtDateTime(expiry_time)
}
timeout => timeout,
}
}
}
impl Default for Timeout {
fn default() -> Self {
Self::After(Duration::from_secs(300))
}
}
/// Value of `CacheConfig::max_retries` when it is not set.
#[cfg(feature = "cache")]
const MAX_RETRIES_DEFAULT: u32 = 3;
/// Cache settings.
#[cfg(feature = "cache")]
#[derive(Debug, Clone, PartialEq, Eq, Builder, Serialize, Deserialize)]
#[builder(build_fn(skip, error = std::convert::Infallible))]
#[serde(default)]
#[non_exhaustive]
pub struct CacheConfig {
    /// Retry count; defaults to `MAX_RETRIES_DEFAULT` (3).
    pub max_retries: u32,
    /// Entry timeout; defaults to [`Timeout::default`]. (De)serialized through
    /// `crate::serializers::cache_timeout`.
    #[serde(with = "crate::serializers::cache_timeout")]
    pub timeout: Timeout,
    /// Optional key prefix; defaults to `None`.
    #[builder(setter(into, strip_option), default)]
    pub prefix: Option<String>,
    /// Storage backend; defaults to the in-memory store.
    #[builder(default)]
    pub store: CacheStoreConfig,
}
#[cfg(feature = "cache")]
impl CacheConfigBuilder {
    /// Builds a [`CacheConfig`], falling back to `MAX_RETRIES_DEFAULT` retries,
    /// [`Timeout::default`], no prefix, and the default store for unset fields.
    #[must_use]
    pub fn build(&self) -> CacheConfig {
        let max_retries = self.max_retries.unwrap_or(MAX_RETRIES_DEFAULT);
        let timeout = self.timeout.unwrap_or_default();
        // The builder stores `Option<Option<String>>`; unset collapses to `None`.
        let prefix = self.prefix.clone().flatten();
        let store = self.store.clone().unwrap_or_default();
        CacheConfig {
            max_retries,
            timeout,
            prefix,
            store,
        }
    }
}
#[cfg(feature = "cache")]
impl Default for CacheConfig {
    /// Same as building an untouched [`CacheConfigBuilder`].
    fn default() -> Self {
        Self::builder().build()
    }
}
#[cfg(feature = "cache")]
impl CacheConfig {
    /// Returns a fresh builder with no fields set.
    #[must_use]
    pub fn builder() -> CacheConfigBuilder {
        Default::default()
    }
}
/// Cache storage backend configuration.
///
/// The store type's fields (including its `type` tag) are flattened into this
/// table, e.g. `[cache.store] type = "memory"`.
#[cfg(feature = "cache")]
#[derive(Debug, Default, Clone, PartialEq, Eq, Builder, Serialize, Deserialize)]
#[builder(build_fn(skip, error = std::convert::Infallible))]
#[serde(default)]
pub struct CacheStoreConfig {
    /// Selected backend; defaults to [`CacheStoreTypeConfig::Memory`].
    #[serde(flatten)]
    pub store_type: CacheStoreTypeConfig,
}
#[cfg(feature = "cache")]
impl CacheStoreConfig {
    /// Returns a fresh builder with no fields set.
    #[must_use]
    pub fn builder() -> CacheStoreConfigBuilder {
        Default::default()
    }
}
#[cfg(feature = "cache")]
impl CacheStoreConfigBuilder {
    /// Builds a [`CacheStoreConfig`], defaulting to the in-memory store.
    #[must_use]
    pub fn build(&self) -> CacheStoreConfig {
        let store_type = self.store_type.as_ref().cloned().unwrap_or_default();
        CacheStoreConfig { store_type }
    }
}
/// Number of pooled connections used by a Redis cache store when
/// `pool_size` is not configured.
#[cfg(feature = "cache")]
pub(crate) const DEFAULT_REDIS_POOL_SIZE: usize = 10;
/// Serde `default` for `CacheStoreTypeConfig::Redis::pool_size`.
#[cfg(feature = "cache")]
const fn default_redis_pool_size() -> usize {
    DEFAULT_REDIS_POOL_SIZE
}
/// Serde `skip_serializing_if` predicate: `true` when `size` equals the
/// default pool size. Takes a reference because that is the signature serde
/// calls it with.
#[expect(clippy::trivially_copy_pass_by_ref)]
#[cfg(feature = "cache")]
const fn is_default_redis_pool_size(size: &usize) -> bool {
    DEFAULT_REDIS_POOL_SIZE == *size
}
/// The cache storage backend, tagged by a `type` key.
#[derive(Debug, Default, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
#[non_exhaustive]
#[cfg(feature = "cache")]
pub enum CacheStoreTypeConfig {
    /// In-memory store (the default); `type = "memory"`.
    #[default]
    Memory,
    /// Redis store; `type = "redis"`.
    Redis {
        /// Redis server URL.
        url: CacheUrl,
        /// Connection pool size; defaults to `default_redis_pool_size()` (10)
        /// and is omitted when serializing if it equals that default.
        #[serde(
            default = "default_redis_pool_size",
            skip_serializing_if = "is_default_redis_pool_size"
        )]
        pool_size: usize,
    },
    /// File-based store; `type = "file"`.
    File {
        /// Storage location on disk.
        path: PathBuf,
    },
}
/// Static file serving settings.
#[derive(Debug, Clone, PartialEq, Eq, Builder, Serialize, Deserialize)]
#[builder(build_fn(skip, error = std::convert::Infallible))]
#[serde(default)]
#[non_exhaustive]
pub struct StaticFilesConfig {
    /// URL prefix under which static files are served; defaults to `"/static/"`.
    #[builder(setter(into))]
    pub url: String,
    /// Path rewrite mode; defaults to [`StaticFilesPathRewriteMode::None`].
    pub rewrite: StaticFilesPathRewriteMode,
    /// Optional cache timeout, written in human-readable form (e.g. `"1h"`) and
    /// (de)serialized via `crate::serializers::humantime`. Defaults to `None`.
    #[serde(with = "crate::serializers::humantime")]
    #[builder(setter(strip_option), default)]
    pub cache_timeout: Option<Duration>,
}
/// How static file paths are rewritten; serialized as `"none"` or
/// `"query_param"`.
#[derive(Debug, Default, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
#[non_exhaustive]
pub enum StaticFilesPathRewriteMode {
    /// No rewriting (the default).
    #[default]
    None,
    /// Rewrite using a query parameter.
    // NOTE(review): the exact rewrite (e.g. a content hash appended as a query
    // parameter) is implemented elsewhere — confirm before relying on it.
    QueryParam,
}
impl StaticFilesConfigBuilder {
    /// Builds a [`StaticFilesConfig`], substituting defaults for unset fields.
    ///
    /// `url` defaults to `"/static/"`, `rewrite` to
    /// [`StaticFilesPathRewriteMode::None`], and `cache_timeout` to `None`.
    #[must_use]
    pub fn build(&self) -> StaticFilesConfig {
        StaticFilesConfig {
            // Build the fallback lazily so no string is allocated when the
            // caller has already provided a URL.
            url: self
                .url
                .clone()
                .unwrap_or_else(|| String::from("/static/")),
            rewrite: self.rewrite.clone().unwrap_or_default(),
            cache_timeout: self.cache_timeout.unwrap_or_default(),
        }
    }
}
impl Default for StaticFilesConfig {
fn default() -> Self {
StaticFilesConfig::builder().build()
}
}
impl StaticFilesConfig {
#[must_use]
pub fn builder() -> StaticFilesConfigBuilder {
StaticFilesConfigBuilder::default()
}
}
/// Per-middleware configuration.
#[derive(Debug, Default, Clone, PartialEq, Eq, Builder, Serialize, Deserialize)]
#[builder(build_fn(skip, error = std::convert::Infallible))]
#[serde(default)]
#[non_exhaustive]
pub struct MiddlewareConfig {
    /// Live-reload middleware settings.
    pub live_reload: LiveReloadMiddlewareConfig,
    /// Session middleware settings.
    pub session: SessionMiddlewareConfig,
}
impl MiddlewareConfig {
#[must_use]
pub fn builder() -> MiddlewareConfigBuilder {
MiddlewareConfigBuilder::default()
}
}
impl MiddlewareConfigBuilder {
    /// Builds a [`MiddlewareConfig`], using each sub-config's default when unset.
    #[must_use]
    pub fn build(&self) -> MiddlewareConfig {
        let live_reload = self.live_reload.as_ref().cloned().unwrap_or_default();
        let session = self.session.as_ref().cloned().unwrap_or_default();
        MiddlewareConfig {
            live_reload,
            session,
        }
    }
}
/// Live-reload middleware settings.
#[derive(Debug, Default, Clone, PartialEq, Eq, Builder, Serialize, Deserialize)]
#[builder(build_fn(skip, error = std::convert::Infallible))]
#[serde(default)]
#[non_exhaustive]
pub struct LiveReloadMiddlewareConfig {
    /// Whether live reload is enabled; defaults to `false`.
    pub enabled: bool,
}
impl LiveReloadMiddlewareConfig {
#[must_use]
pub fn builder() -> LiveReloadMiddlewareConfigBuilder {
LiveReloadMiddlewareConfigBuilder::default()
}
}
impl LiveReloadMiddlewareConfigBuilder {
    /// Builds a [`LiveReloadMiddlewareConfig`]; live reload is off unless enabled.
    #[must_use]
    pub fn build(&self) -> LiveReloadMiddlewareConfig {
        let enabled = self.enabled.unwrap_or(false);
        LiveReloadMiddlewareConfig { enabled }
    }
}
/// The session storage backend, tagged by a `type` key.
#[derive(Debug, Default, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case", tag = "type")]
pub enum SessionStoreTypeConfig {
    /// In-memory store (the default); `type = "memory"`.
    #[default]
    Memory,
    /// Database store (requires the `db` and `json` features); `type = "database"`.
    #[cfg(all(feature = "db", feature = "json"))]
    Database,
    /// File store (requires the `json` feature); `type = "file"`.
    #[cfg(feature = "json")]
    File {
        /// Storage location on disk.
        path: PathBuf,
    },
    /// Cache-backed store (requires the `cache` feature); `type = "cache"`.
    #[cfg(feature = "cache")]
    Cache {
        /// Cache server URL.
        uri: CacheUrl,
    },
}
/// Session storage configuration; the store type's fields (including its
/// `type` tag) are flattened into this table.
#[derive(Debug, Default, Clone, PartialEq, Eq, Builder, Serialize, Deserialize)]
#[builder(build_fn(skip, error = std::convert::Infallible))]
#[serde(default)]
pub struct SessionStoreConfig {
    /// Selected backend; defaults to [`SessionStoreTypeConfig::Memory`].
    #[serde(flatten)]
    pub store_type: SessionStoreTypeConfig,
}
impl SessionStoreConfig {
#[must_use]
pub fn builder() -> SessionStoreConfigBuilder {
SessionStoreConfigBuilder::default()
}
}
impl SessionStoreConfigBuilder {
    /// Builds a [`SessionStoreConfig`], defaulting to the in-memory store.
    #[must_use]
    pub fn build(&self) -> SessionStoreConfig {
        let store_type = self.store_type.as_ref().cloned().unwrap_or_default();
        SessionStoreConfig { store_type }
    }
}
/// Cookie `SameSite` policy; serialized as `"strict"`, `"lax"`, or `"none"`
/// and converted into the `tower_sessions` cookie type via `From`.
#[derive(Debug, Default, Copy, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
#[non_exhaustive]
pub enum SameSite {
    /// `SameSite=Strict` (the default).
    #[default]
    Strict,
    /// `SameSite=Lax`.
    Lax,
    /// `SameSite=None`.
    None,
}
impl From<SameSite> for tower_sessions::cookie::SameSite {
    /// Maps each variant onto the identically named cookie `SameSite` variant.
    fn from(value: SameSite) -> Self {
        use tower_sessions::cookie::SameSite as CookieSameSite;

        match value {
            SameSite::Lax => CookieSameSite::Lax,
            SameSite::None => CookieSameSite::None,
            SameSite::Strict => CookieSameSite::Strict,
        }
    }
}
/// Session expiry policy; converted into `tower_sessions::Expiry` via `From`.
#[derive(Debug, Default, Copy, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
#[non_exhaustive]
pub enum Expiry {
    /// Expire when the session ends (the default).
    #[default]
    OnSessionEnd,
    /// Expire after the given period of inactivity.
    OnInactivity(Duration),
    /// Expire at the given absolute instant.
    AtDateTime(DateTime<FixedOffset>),
}
impl From<Expiry> for tower_sessions::Expiry {
    /// Converts into the equivalent `tower_sessions` expiry.
    ///
    /// # Panics
    ///
    /// Panics if an `OnInactivity` duration cannot be converted into a
    /// `time::Duration`.
    fn from(value: Expiry) -> Self {
        match value {
            Expiry::OnSessionEnd => Self::OnSessionEnd,
            Expiry::OnInactivity(duration) => {
                let inactivity = time::Duration::try_from(duration).unwrap_or_else(|e| {
                    panic!("could not convert {duration:?} into a valid time::Duration: {e:?}")
                });
                Self::OnInactivity(inactivity)
            }
            Expiry::AtDateTime(time) => {
                let offset_date_time = DateTimeWithOffsetAdapter::new(time).into_offsetdatetime();
                Self::AtDateTime(offset_date_time)
            }
        }
    }
}
/// Session middleware (cookie and storage) settings.
#[derive(Debug, Clone, PartialEq, Eq, Builder, Serialize, Deserialize)]
#[builder(build_fn(skip, error = std::convert::Infallible))]
#[serde(default)]
#[non_exhaustive]
pub struct SessionMiddlewareConfig {
    /// Cookie `Secure` flag; defaults to `true`.
    pub secure: bool,
    /// Cookie `HttpOnly` flag; defaults to `true`.
    pub http_only: bool,
    /// Cookie `SameSite` policy; defaults to [`SameSite::Strict`].
    pub same_site: SameSite,
    /// Cookie domain; defaults to `None`.
    #[builder(setter(strip_option), default)]
    pub domain: Option<String>,
    /// Cookie path; defaults to `"/"`.
    pub path: String,
    /// Cookie name; defaults to `"id"`.
    pub name: String,
    /// Whether to save the session on every request; defaults to `false`.
    pub always_save: bool,
    /// Expiry policy; defaults to [`Expiry::OnSessionEnd`]. (De)serialized via
    /// `crate::serializers::session_expiry_time`.
    #[serde(with = "crate::serializers::session_expiry_time")]
    pub expiry: Expiry,
    /// Session store; defaults to the in-memory store.
    pub store: SessionStoreConfig,
}
impl SessionMiddlewareConfig {
#[must_use]
pub fn builder() -> SessionMiddlewareConfigBuilder {
SessionMiddlewareConfigBuilder::default()
}
}
impl SessionMiddlewareConfigBuilder {
    /// Builds a [`SessionMiddlewareConfig`], substituting defaults for unset
    /// fields:
    /// - `secure` and `http_only` default to `true`.
    /// - `same_site` defaults to [`SameSite::Strict`].
    /// - `domain` defaults to `None`.
    /// - `name` defaults to `"id"` and `path` to `"/"`.
    /// - `always_save` defaults to `false`.
    /// - `expiry` defaults to [`Expiry::OnSessionEnd`].
    /// - `store` defaults to the in-memory store.
    #[must_use]
    pub fn build(&self) -> SessionMiddlewareConfig {
        SessionMiddlewareConfig {
            secure: self.secure.unwrap_or(true),
            http_only: self.http_only.unwrap_or(true),
            same_site: self.same_site.unwrap_or_default(),
            domain: self.domain.clone().unwrap_or_default(),
            // String fallbacks are built lazily so nothing is allocated when
            // the caller has already provided a value.
            name: self.name.clone().unwrap_or_else(|| String::from("id")),
            path: self.path.clone().unwrap_or_else(|| String::from("/")),
            always_save: self.always_save.unwrap_or(false),
            expiry: self.expiry.unwrap_or_default(),
            store: self.store.clone().unwrap_or_default(),
        }
    }
}
impl Default for SessionMiddlewareConfig {
fn default() -> Self {
SessionMiddlewareConfig::builder().build()
}
}
/// Email transport backend, tagged by a `type` key.
#[cfg(feature = "email")]
#[derive(Debug, Default, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
#[non_exhaustive]
pub enum EmailTransportTypeConfig {
    /// Console transport (the default); `type = "console"`.
    #[default]
    Console,
    /// SMTP transport; `type = "smtp"`.
    Smtp {
        /// SMTP server URL (may include credentials).
        url: EmailUrl,
        /// SMTP authentication mechanism.
        mechanism: Mechanism,
    },
}
/// Email transport configuration; the transport type's fields (including its
/// `type` tag) are flattened into this table, e.g. `[email.transport]`.
#[cfg(feature = "email")]
#[derive(Debug, Default, Clone, PartialEq, Eq, Builder, Serialize, Deserialize)]
#[builder(build_fn(skip, error = std::convert::Infallible))]
#[serde(default)]
pub struct EmailTransportConfig {
    /// Selected transport; defaults to [`EmailTransportTypeConfig::Console`].
    #[serde(flatten)]
    pub transport_type: EmailTransportTypeConfig,
}
#[cfg(feature = "email")]
impl EmailTransportConfig {
    /// Returns a fresh builder with no fields set.
    #[must_use]
    pub fn builder() -> EmailTransportConfigBuilder {
        Default::default()
    }
}
#[cfg(feature = "email")]
impl EmailTransportConfigBuilder {
    /// Builds an [`EmailTransportConfig`], defaulting to the console transport.
    #[must_use]
    pub fn build(&self) -> EmailTransportConfig {
        let transport_type = self.transport_type.as_ref().cloned().unwrap_or_default();
        EmailTransportConfig { transport_type }
    }
}
/// Email settings.
#[cfg(feature = "email")]
#[derive(Debug, Clone, PartialEq, Eq, Builder, Serialize, Deserialize)]
#[builder(build_fn(skip, error = std::convert::Infallible))]
#[serde(default)]
pub struct EmailConfig {
    /// Transport configuration; defaults to the console transport.
    #[builder(default)]
    pub transport: EmailTransportConfig,
}
#[cfg(feature = "email")]
impl EmailConfig {
    /// Returns a fresh builder with no fields set.
    #[must_use]
    pub fn builder() -> EmailConfigBuilder {
        Default::default()
    }
}
#[cfg(feature = "email")]
impl EmailConfigBuilder {
    /// Builds an [`EmailConfig`], defaulting to the console transport.
    #[must_use]
    pub fn build(&self) -> EmailConfig {
        let transport = self.transport.as_ref().cloned().unwrap_or_default();
        EmailConfig { transport }
    }
}
#[cfg(feature = "email")]
impl Default for EmailConfig {
    /// Same as building an untouched [`EmailConfigBuilder`].
    fn default() -> Self {
        Self::builder().build()
    }
}
/// A secret key stored as raw bytes.
///
/// Equality is checked in constant time via `subtle::ConstantTimeEq`, and the
/// `Debug` output is redacted. When deserialized, the key is read from a
/// string and its UTF-8 bytes are kept.
///
/// NOTE(review): `Serialize` is derived on the raw `Box<[u8]>` (serialized by
/// serde as a sequence), while `Deserialize` expects a string
/// (`from = "String"`). A serialized key therefore does not deserialize back;
/// confirm whether round-tripping is required.
#[repr(transparent)]
#[derive(Clone, Serialize, Deserialize)]
#[serde(from = "String")]
pub struct SecretKey(Box<[u8]>);
impl SecretKey {
    /// Creates a key by copying `hash`.
    #[must_use]
    pub fn new(hash: &[u8]) -> Self {
        Self(hash.into())
    }

    /// Borrows the key's bytes.
    #[must_use]
    pub fn as_bytes(&self) -> &[u8] {
        &self.0
    }

    /// Consumes the key and returns its bytes.
    #[must_use]
    pub fn into_bytes(self) -> Box<[u8]> {
        let Self(bytes) = self;
        bytes
    }
}
impl From<&[u8]> for SecretKey {
    fn from(value: &[u8]) -> Self {
        SecretKey::new(value)
    }
}
impl From<String> for SecretKey {
    /// Takes ownership of the string's UTF-8 bytes.
    fn from(value: String) -> Self {
        Self(value.into_bytes().into_boxed_slice())
    }
}
impl From<&str> for SecretKey {
    /// Copies the string's UTF-8 bytes.
    fn from(value: &str) -> Self {
        Self::from(value.as_bytes())
    }
}
impl PartialEq for SecretKey {
    /// Compares the keys in constant time with respect to their contents.
    fn eq(&self, other: &Self) -> bool {
        bool::from(self.as_bytes().ct_eq(other.as_bytes()))
    }
}
impl Eq for SecretKey {}
impl Debug for SecretKey {
    /// Never prints the key material.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("SecretKey(\"**********\")")
    }
}
impl Default for SecretKey {
    /// An empty key.
    fn default() -> Self {
        Self(Box::default())
    }
}
/// A database connection URL.
///
/// Serialized as a plain string; the `Debug` output masks any username and
/// password.
#[derive(Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
#[serde(transparent)]
#[cfg(feature = "db")]
pub struct DatabaseUrl(url::Url);
#[cfg(feature = "db")]
impl DatabaseUrl {
    /// Returns the full URL text, including any credentials.
    #[must_use]
    pub fn as_str(&self) -> &str {
        self.0.as_str()
    }
}
#[cfg(feature = "db")]
impl From<String> for DatabaseUrl {
    /// # Panics
    ///
    /// Panics if `url` cannot be parsed as a URL.
    fn from(url: String) -> Self {
        Self::from(url.as_str())
    }
}
#[cfg(feature = "db")]
impl From<&str> for DatabaseUrl {
    /// # Panics
    ///
    /// Panics if `url` cannot be parsed as a URL.
    fn from(url: &str) -> Self {
        let parsed = url::Url::parse(url).expect("valid URL");
        Self(parsed)
    }
}
#[cfg(feature = "db")]
impl Debug for DatabaseUrl {
    /// Prints the URL with credentials replaced by asterisks.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let masked = conceal_url_parts(&self.0);
        f.debug_tuple("DatabaseUrl").field(&masked.as_str()).finish()
    }
}
/// Error returned when a string does not name a supported [`CacheType`].
#[derive(Debug, Error, PartialEq, Eq)]
#[non_exhaustive]
pub enum ParseCacheTypeError {
    /// The given name (or URL scheme) is not a supported cache type.
    #[error("unsupported cache type: `{0}`")]
    Unsupported(String),
}
/// A cache backend kind, parsed from a name such as a URL scheme.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[cfg(feature = "cache")]
#[non_exhaustive]
pub enum CacheType {
    /// Redis (requires the `redis` feature); parsed from `"redis"`.
    #[cfg(feature = "redis")]
    Redis,
}
#[cfg(feature = "cache")]
impl TryFrom<&str> for CacheType {
    type Error = ParseCacheTypeError;

    /// Parses an exact, case-sensitive cache type name.
    ///
    /// # Errors
    ///
    /// Returns [`ParseCacheTypeError::Unsupported`] carrying the input when
    /// it names no enabled cache type.
    fn try_from(value: &str) -> Result<Self, Self::Error> {
        match value {
            #[cfg(feature = "redis")]
            "redis" => Ok(Self::Redis),
            unsupported => Err(ParseCacheTypeError::Unsupported(String::from(unsupported))),
        }
    }
}
#[cfg(feature = "cache")]
impl std::str::FromStr for CacheType {
    type Err = ParseCacheTypeError;

    /// Delegates to the `TryFrom<&str>` implementation.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        s.try_into()
    }
}
/// A cache backend URL.
///
/// Serialized as a plain string; the `Debug` output masks credentials.
#[derive(Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
#[serde(transparent)]
#[cfg(feature = "cache")]
pub struct CacheUrl(url::Url);
#[cfg(feature = "cache")]
impl CacheUrl {
    /// Returns the full URL text, including any credentials.
    #[must_use]
    pub fn as_str(&self) -> &str {
        let Self(inner) = self;
        inner.as_str()
    }

    /// Returns the URL scheme (e.g. `redis`).
    #[must_use]
    pub fn scheme(&self) -> &str {
        let Self(inner) = self;
        inner.scheme()
    }

    /// Mutable access to the underlying URL.
    #[allow(clippy::allow_attributes, unused, reason = "used in tests")]
    pub(crate) fn inner_mut(&mut self) -> &mut url::Url {
        let Self(inner) = self;
        inner
    }
}
#[cfg(feature = "cache")]
impl From<String> for CacheUrl {
    /// # Panics
    ///
    /// Panics with `invalid cache URL` if `url` cannot be parsed.
    fn from(url: String) -> Self {
        Self::from(url.as_str())
    }
}
#[cfg(feature = "cache")]
impl From<&str> for CacheUrl {
    /// # Panics
    ///
    /// Panics with `invalid cache URL` if `url` cannot be parsed.
    fn from(url: &str) -> Self {
        let parsed = url::Url::parse(url).expect("invalid cache URL");
        Self(parsed)
    }
}
#[cfg(feature = "cache")]
impl TryFrom<CacheUrl> for CacheType {
    type Error = ParseCacheTypeError;

    /// Determines the cache type from the URL's scheme.
    ///
    /// # Errors
    ///
    /// Returns [`ParseCacheTypeError::Unsupported`] with the scheme when it is
    /// not a supported cache type.
    fn try_from(value: CacheUrl) -> Result<Self, Self::Error> {
        let scheme = value.scheme();
        Self::try_from(scheme)
    }
}
#[cfg(feature = "cache")]
impl Debug for CacheUrl {
    /// Prints the URL with credentials replaced by asterisks.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let masked = conceal_url_parts(&self.0);
        let mut tuple = f.debug_tuple("CacheUrl");
        tuple.field(&masked.as_str());
        tuple.finish()
    }
}
/// Returns a copy of `url` whose username and password (when present) are
/// replaced with `********`, for safe inclusion in debug output.
#[cfg(any(feature = "cache", feature = "db"))]
fn conceal_url_parts(url: &url::Url) -> url::Url {
    const MASK: &str = "********";

    let mut concealed = url.clone();
    let has_username = !concealed.username().is_empty();
    let has_password = concealed.password().is_some();
    // Both setters only fail for URLs that cannot carry credentials; a URL
    // that already has them can always have them replaced.
    if has_username {
        concealed
            .set_username(MASK)
            .expect("set_username should succeed if username is present");
    }
    if has_password {
        concealed
            .set_password(Some(MASK))
            .expect("set_password should succeed if password is present");
    }
    concealed
}
#[cfg(feature = "cache")]
impl std::fmt::Display for CacheUrl {
    /// Writes the full URL text.
    ///
    /// Unlike the `Debug` impl, this does NOT mask credentials, so avoid using
    /// it in logs.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(self.as_str())
    }
}
/// An email transport URL (e.g. an SMTP server address, possibly with
/// credentials).
///
/// Serialized as a plain string. The `Debug` output masks the username and
/// password, in line with the other URL newtypes in this module, so SMTP
/// passwords do not leak into logs.
#[derive(Clone, Hash, PartialEq, Eq, Serialize, Deserialize)]
#[serde(transparent)]
#[cfg(feature = "email")]
pub struct EmailUrl(url::Url);
#[cfg(feature = "email")]
impl EmailUrl {
    /// Returns the full URL text, including any credentials.
    #[must_use]
    pub fn as_str(&self) -> &str {
        self.0.as_str()
    }
}
#[cfg(feature = "email")]
impl From<String> for EmailUrl {
    /// # Panics
    ///
    /// Panics if `url` cannot be parsed as a URL.
    fn from(url: String) -> Self {
        Self(url::Url::parse(&url).expect("valid URL"))
    }
}
#[cfg(feature = "email")]
impl From<&str> for EmailUrl {
    /// # Panics
    ///
    /// Panics if `url` cannot be parsed as a URL.
    fn from(url: &str) -> Self {
        Self(url::Url::parse(url).expect("valid URL"))
    }
}
#[cfg(feature = "email")]
impl Debug for EmailUrl {
    /// Prints the URL with any username/password replaced by asterisks.
    ///
    /// Masking is done inline because the shared `conceal_url_parts` helper is
    /// only compiled with the `cache` or `db` features.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut concealed = self.0.clone();
        if !concealed.username().is_empty() {
            concealed
                .set_username("********")
                .expect("set_username should succeed if username is present");
        }
        if concealed.password().is_some() {
            concealed
                .set_password(Some("********"))
                .expect("set_password should succeed if password is present");
        }
        f.debug_tuple("EmailUrl").field(&concealed.as_str()).finish()
    }
}
#[cfg(test)]
mod tests {
    use time::OffsetDateTime;

    use super::*;

    #[test]
    fn from_toml_valid() {
        let toml_content = r#"
debug = true
register_panic_hook = true
secret_key = "123abc"
fallback_secret_keys = ["456def", "789ghi"]
auth_backend = { type = "none" }
[static_files]
url = "/assets/"
rewrite = "none"
cache_timeout = "1h"
[middlewares]
live_reload.enabled = true
[middlewares.session]
secure = false
http_only = false
domain = "localhost"
path = "/some/path"
always_save = true
name = "some.sid"
"#;
        let config = ProjectConfig::from_toml(toml_content).unwrap();
        assert!(config.debug);
        assert!(config.register_panic_hook);
        assert_eq!(config.secret_key.as_bytes(), b"123abc");
        assert_eq!(config.fallback_secret_keys.len(), 2);
        assert_eq!(config.fallback_secret_keys[0].as_bytes(), b"456def");
        assert_eq!(config.fallback_secret_keys[1].as_bytes(), b"789ghi");
        assert_eq!(config.auth_backend, AuthBackendConfig::None);
        assert_eq!(config.static_files.url, "/assets/");
        assert_eq!(
            config.static_files.rewrite,
            StaticFilesPathRewriteMode::None
        );
        assert_eq!(
            config.static_files.cache_timeout,
            Some(Duration::from_secs(3600))
        );
        assert!(config.middlewares.live_reload.enabled);
        assert!(!config.middlewares.session.secure);
        assert!(!config.middlewares.session.http_only);
        assert_eq!(
            config.middlewares.session.domain,
            Some(String::from("localhost"))
        );
        assert!(config.middlewares.session.always_save);
        assert_eq!(config.middlewares.session.name, String::from("some.sid"));
        assert_eq!(config.middlewares.session.path, String::from("/some/path"));
    }

    #[test]
    fn default_values_from_valid_toml() {
        let toml_content = "";
        let config = ProjectConfig::from_toml(toml_content).unwrap();
        // `debug` defaults to whether debug assertions are enabled, so the
        // expectation must follow the profile the tests are built with.
        assert_eq!(config.debug, cfg!(debug_assertions));
        assert!(config.register_panic_hook);
        assert_eq!(config.secret_key.as_bytes(), b"");
        assert_eq!(config.fallback_secret_keys.len(), 0);
        assert_eq!(config.auth_backend, AuthBackendConfig::None);
        assert_eq!(config.static_files.url, "/static/");
        assert_eq!(
            config.static_files.rewrite,
            StaticFilesPathRewriteMode::None
        );
        assert_eq!(config.static_files.cache_timeout, None);
        assert!(!config.middlewares.live_reload.enabled);
        assert!(config.middlewares.session.secure);
        assert!(config.middlewares.session.http_only);
        assert_eq!(config.middlewares.session.domain, None);
        assert!(!config.middlewares.session.always_save);
        assert_eq!(config.middlewares.session.name, String::from("id"));
        assert_eq!(config.middlewares.session.path, String::from("/"));
        assert_eq!(config.middlewares.session.same_site, SameSite::Strict);
        assert_eq!(config.middlewares.session.expiry, Expiry::OnSessionEnd);
        assert_eq!(
            config.middlewares.session.store.store_type,
            SessionStoreTypeConfig::Memory
        );
        // `ProjectConfig::database` only exists with the `db` feature.
        #[cfg(feature = "db")]
        assert_eq!(config.database.url, None);
    }

    #[test]
    fn same_site_from_valid_toml() {
        let same_site_options = [
            (
                "none",
                SameSite::None,
                tower_sessions::cookie::SameSite::None,
            ),
            ("lax", SameSite::Lax, tower_sessions::cookie::SameSite::Lax),
            (
                "strict",
                SameSite::Strict,
                tower_sessions::cookie::SameSite::Strict,
            ),
        ];
        for (value, expected, tower_sessions_expected) in same_site_options {
            let toml_content = format!(
                r#"
[middlewares.session]
same_site = "{value}"
"#
            );
            let config = ProjectConfig::from_toml(&toml_content).unwrap();
            let actual = config.middlewares.session.same_site;
            assert_eq!(actual, expected);
            assert_eq!(
                tower_sessions::cookie::SameSite::from(actual),
                tower_sessions_expected
            );
        }
    }

    #[test]
    fn expiry_from_valid_toml() {
        let expiry_opts = [
            (
                "2h",
                Expiry::OnInactivity(Duration::from_secs(7200)),
                tower_sessions::Expiry::OnInactivity(time::Duration::seconds(7200)),
            ),
            (
                "2025-12-31T23:59:59+00:00",
                Expiry::AtDateTime(
                    DateTime::parse_from_rfc3339("2025-12-31T23:59:59+00:00").unwrap(),
                ),
                tower_sessions::Expiry::AtDateTime(OffsetDateTime::new_utc(
                    time::Date::from_calendar_date(2025, time::Month::December, 31).unwrap(),
                    time::Time::from_hms(23, 59, 59).unwrap(),
                )),
            ),
        ];
        for (value, expected, tower_session_expected) in expiry_opts {
            let toml_content = format!(
                r#"
[middlewares.session]
expiry = "{value}"
"#
            );
            let config = ProjectConfig::from_toml(&toml_content).unwrap();
            let actual = config.middlewares.session.expiry;
            assert_eq!(actual, expected);
            assert_eq!(tower_sessions::Expiry::from(actual), tower_session_expected);
        }
    }

    #[test]
    fn expiry_from_invalid_toml() {
        let toml_content = r#"
[middlewares.session]
expiry = "invalid time"
"#
        .to_string();
        let config = ProjectConfig::from_toml(&toml_content);
        assert!(config.is_err());
        assert!(
            config
                .unwrap_err()
                .to_string()
                .contains("could not parse the config")
        );
    }

    #[test]
    fn session_store_valid_toml() {
        let toml_content = r#"
debug = true
register_panic_hook = true
secret_key = "123abc"
fallback_secret_keys = ["456def", "789ghi"]
auth_backend = { type = "none" }
[static_files]
url = "/assets/"
rewrite = "none"
cache_timeout = "1h"
[middlewares]
live_reload.enabled = true
[middlewares.session]
secure = false
"#;
        let store_configs = [
            (
                r#"
[middlewares.session.store]
type = "memory"
"#,
                SessionStoreTypeConfig::Memory,
            ),
            #[cfg(feature = "cache")]
            (
                r#"
[middlewares.session.store]
type = "cache"
uri = "redis://redis"
"#,
                SessionStoreTypeConfig::Cache {
                    uri: CacheUrl::from("redis://redis"),
                },
            ),
            // `SessionStoreTypeConfig::File` only exists with the `json` feature.
            #[cfg(feature = "json")]
            (
                r#"
[middlewares.session.store]
type = "file"
path = "session/path/"
"#,
                SessionStoreTypeConfig::File {
                    path: PathBuf::from("session/path"),
                },
            ),
            #[cfg(all(feature = "db", feature = "json"))]
            (
                r#"
[middlewares.session.store]
type = "database"
"#,
                SessionStoreTypeConfig::Database,
            ),
        ];
        for (cfg_toml, cfg_type) in store_configs {
            let full_cfg_str = format!("{toml_content}\n{cfg_toml}");
            let config = ProjectConfig::from_toml(&full_cfg_str).unwrap();
            assert_eq!(config.middlewares.session.store.store_type, cfg_type);
        }
    }

    #[test]
    fn from_toml_invalid() {
        let toml_content = r"
debug = true
secret_key = 123abc
";
        let result = ProjectConfig::from_toml(toml_content);
        assert!(result.is_err());
    }

    #[test]
    fn from_toml_missing_fields() {
        let toml_content = r#"
secret_key = "123abc"
[static_files]
rewrite = "query_param"
"#;
        let config = ProjectConfig::from_toml(toml_content).unwrap();
        assert_eq!(config.debug, cfg!(debug_assertions));
        assert_eq!(config.secret_key.as_bytes(), b"123abc");
        assert_eq!(config.static_files.url, "/static/");
        assert_eq!(
            config.static_files.rewrite,
            StaticFilesPathRewriteMode::QueryParam
        );
    }

    #[test]
    #[cfg(feature = "redis")]
    fn cache_type_from_str_redis() {
        assert_eq!(CacheType::try_from("redis").unwrap(), CacheType::Redis);
    }

    #[test]
    #[cfg(feature = "cache")]
    fn cache_type_from_str_unknown() {
        for &s in &["", "foo", "redis://foo"] {
            assert_eq!(
                CacheType::try_from(s),
                Err(ParseCacheTypeError::Unsupported(s.to_owned()))
            );
        }
    }

    #[test]
    #[cfg(feature = "redis")]
    fn cache_type_from_cacheurl() {
        let url = CacheUrl::from("redis://localhost/");
        assert_eq!(CacheType::try_from(url.clone()).unwrap(), CacheType::Redis);
        let other = CacheUrl::from("http://example.com/");
        assert_eq!(
            CacheType::try_from(other),
            Err(ParseCacheTypeError::Unsupported("http".to_string()))
        );
    }

    #[test]
    #[cfg(feature = "cache")]
    fn cacheurl_from_str_and_string() {
        let s = "http://example.com/foo";
        let u1 = CacheUrl::from(s);
        let u2 = CacheUrl::from(s.to_string());
        assert_eq!(u1, u2);
        assert_eq!(u1.as_str(), s);
    }

    #[test]
    #[cfg(feature = "cache")]
    #[should_panic(expected = "invalid cache URL")]
    fn cacheurl_from_invalid_str_panics() {
        let _ = CacheUrl::from("not a url");
    }

    #[test]
    #[cfg(feature = "cache")]
    fn cacheurl_as_str_roundtrip() {
        let raw = "https://user:pass@host:1234/path?query#frag";
        let cu = CacheUrl::from(raw);
        assert_eq!(cu.as_str(), url::Url::parse(raw).unwrap().as_str());
    }

    #[test]
    #[cfg(feature = "cache")]
    fn cacheurl_debug_masks_credentials() {
        let raw = "https://user:secret@host:1234/path";
        let cu = CacheUrl::from(raw);
        let dbg = format!("{cu:?}");
        assert!(dbg.starts_with("CacheUrl(\"https://********:********@host:1234/path\")"));
    }

    // `conceal_url_parts` is only compiled with the `cache` or `db` features.
    #[test]
    #[cfg(any(feature = "cache", feature = "db"))]
    fn conceal_url_details_leaves_no_credentials() {
        let raw = "ftp://alice:alicepwd@server/";
        let parsed = url::Url::parse(raw).unwrap();
        let concealed = conceal_url_parts(&parsed);
        assert_eq!(concealed.username(), "********");
        assert_eq!(concealed.password(), Some("********"));
    }

    #[test]
    #[cfg(feature = "cache")]
    fn cache_config_from_toml_memory() {
        let toml_content = r#"
[cache]
max_retries = 5
timeout = "60s"
prefix = "v1"
[cache.store]
type = "memory"
"#;
        let config = ProjectConfig::from_toml(toml_content).unwrap();
        assert_eq!(config.cache.max_retries, 5);
        assert_eq!(
            config.cache.timeout,
            Timeout::After(Duration::from_secs(60))
        );
        assert_eq!(config.cache.prefix, Some("v1".to_string()));
        assert_eq!(config.cache.store.store_type, CacheStoreTypeConfig::Memory);
    }

    #[test]
    #[cfg(feature = "cache")]
    fn cache_config_from_toml_redis() {
        const WITH_POOL: &str = r#"
[cache]
max_retries = 10
timeout = "120s"
[cache.store]
type = "redis"
url = "redis://localhost:6379"
pool_size = 20
"#;
        const WITHOUT_POOL: &str = r#"
[cache]
max_retries = 10
timeout = "120s"
[cache.store]
type = "redis"
url = "redis://localhost:6379"
"#;
        let variants: [(&str, usize); 2] = [
            (WITH_POOL, 20),
            (WITHOUT_POOL, default_redis_pool_size()),
        ];
        for (toml_content, expected_size) in variants {
            let config = ProjectConfig::from_toml(toml_content).unwrap();
            assert_eq!(config.cache.max_retries, 10);
            assert_eq!(
                config.cache.timeout,
                Timeout::After(Duration::from_secs(120))
            );
            assert_eq!(config.cache.prefix, None);
            // Fail loudly on a variant mismatch instead of silently skipping
            // the assertions.
            let CacheStoreTypeConfig::Redis { url, pool_size } = &config.cache.store.store_type
            else {
                panic!(
                    "expected a Redis cache store, got {:?}",
                    config.cache.store.store_type
                );
            };
            assert_eq!(url.as_str(), "redis://localhost:6379");
            assert_eq!(*pool_size, expected_size);
        }
    }

    #[test]
    #[cfg(feature = "cache")]
    fn cache_config_from_toml_file() {
        let toml_content = r#"
[cache]
max_retries = 3
timeout = "30s"
prefix = "dev"
[cache.store]
type = "file"
path = "/tmp/cache"
"#;
        let config = ProjectConfig::from_toml(toml_content).unwrap();
        assert_eq!(config.cache.max_retries, 3);
        assert_eq!(
            config.cache.timeout,
            Timeout::After(Duration::from_secs(30))
        );
        assert_eq!(config.cache.prefix, Some("dev".to_string()));
        let CacheStoreTypeConfig::File { path } = &config.cache.store.store_type else {
            panic!(
                "expected a file cache store, got {:?}",
                config.cache.store.store_type
            );
        };
        assert_eq!(path, &PathBuf::from("/tmp/cache"));
    }

    #[test]
    #[cfg(feature = "cache")]
    fn cache_config_defaults() {
        let toml_content = r"
[cache]
";
        let config = ProjectConfig::from_toml(toml_content).unwrap();
        assert_eq!(config.cache.max_retries, 3);
        assert_eq!(config.cache.timeout, Timeout::default());
        assert_eq!(config.cache.prefix, None);
        assert_eq!(config.cache.store.store_type, CacheStoreTypeConfig::Memory);
    }

    #[test]
    #[cfg(feature = "cache")]
    fn test_is_default_redis_pool_size() {
        assert!(is_default_redis_pool_size(&10));
    }

    #[test]
    #[cfg(feature = "cache")]
    fn cache_config_builder() {
        let config = CacheConfig::builder()
            .max_retries(7)
            .timeout(Timeout::After(Duration::from_secs(90)))
            .prefix("v2".to_string())
            .store(CacheStoreConfig {
                store_type: CacheStoreTypeConfig::Memory,
            })
            .build();
        assert_eq!(config.max_retries, 7);
        assert_eq!(config.timeout, Timeout::After(Duration::from_secs(90)));
        assert_eq!(config.prefix, Some("v2".to_string()));
        assert_eq!(config.store.store_type, CacheStoreTypeConfig::Memory);
    }

    #[test]
    #[cfg(feature = "cache")]
    fn cache_config_builder_defaults() {
        let config = CacheConfig::builder().build();
        assert_eq!(config.max_retries, 3);
        assert_eq!(config.timeout, Timeout::default());
        assert_eq!(config.prefix, None);
        assert_eq!(config.store.store_type, CacheStoreTypeConfig::Memory);
    }

    #[test]
    #[cfg(feature = "cache")]
    fn cache_store_config_builder() {
        let config = CacheStoreConfig {
            store_type: CacheStoreTypeConfig::Redis {
                url: CacheUrl::from("redis://localhost:6379"),
                pool_size: 15,
            },
        };
        let CacheStoreTypeConfig::Redis { url, pool_size } = &config.store_type else {
            panic!("expected a Redis cache store, got {:?}", config.store_type);
        };
        assert_eq!(url.as_str(), "redis://localhost:6379");
        assert_eq!(*pool_size, 15);
    }

    #[test]
    #[cfg(feature = "cache")]
    fn cache_store_config_default() {
        let config = CacheStoreConfig::default();
        assert_eq!(config.store_type, CacheStoreTypeConfig::Memory);
    }

    #[test]
    fn never_is_never_expired() {
        let now_fixed: DateTime<FixedOffset> =
            Utc::now().with_timezone(&FixedOffset::east_opt(0).unwrap());
        assert!(!Timeout::Never.is_expired(Some(now_fixed)));
        assert!(!Timeout::Never.is_expired(None));
    }

    #[test]
    fn after_is_expired_based_on_insertion_offset() {
        let offset = FixedOffset::east_opt(3600).unwrap();
        let insertion_time: DateTime<FixedOffset> =
            (Utc::now() - chrono::Duration::hours(1)).with_timezone(&offset);
        let timeout = Timeout::After(Duration::from_secs(60));
        assert!(timeout.is_expired(Some(insertion_time)));
    }

    #[test]
    fn after_is_not_expired_when_not_yet_passed_with_offset() {
        let offset = FixedOffset::east_opt(-2 * 3600).unwrap();
        let insertion_time: DateTime<FixedOffset> =
            (Utc::now() - chrono::Duration::seconds(10)).with_timezone(&offset);
        let timeout = Timeout::After(Duration::from_secs(60));
        assert!(!timeout.is_expired(Some(insertion_time)));
    }

    #[test]
    #[should_panic(expected = "insertion_time is required for Timeout::After expiry check")]
    fn after_is_expired_panics_with_no_insertion_time() {
        let timeout = Timeout::After(Duration::from_secs(60));
        let _ = timeout.is_expired(None);
    }

    #[test]
    fn atdatetime_respects_stored_offset_when_comparing() {
        let offset = FixedOffset::east_opt(3600).unwrap();
        let past: DateTime<FixedOffset> =
            (Utc::now() - chrono::Duration::seconds(60)).with_timezone(&offset);
        let future: DateTime<FixedOffset> =
            (Utc::now() + chrono::Duration::seconds(60)).with_timezone(&offset);
        assert!(Timeout::AtDateTime(past).is_expired(None));
        assert!(!Timeout::AtDateTime(future).is_expired(None));
    }

    #[test]
    fn canonicalize_after_produces_atdatetime_in_utc_offset_zero() {
        let before = Utc::now().with_timezone(&FixedOffset::east_opt(0).unwrap());
        let duration = Duration::from_secs(2);
        let canon = Timeout::After(duration).canonicalize();
        match canon {
            Timeout::AtDateTime(dt) => {
                assert_eq!(dt.offset().local_minus_utc(), 0);
                assert!(dt >= before);
                let max_allowed = before
                    + chrono::Duration::from_std(duration).unwrap()
                    + chrono::Duration::seconds(1);
                assert!(
                    dt <= max_allowed,
                    "canonicalized datetime is unexpectedly far ahead"
                );
            }
            other => panic!("expected AtDateTime, got {other:?}"),
        }
    }

    #[test]
    fn canonicalize_preserves_atdatetime_and_never() {
        let dt: DateTime<FixedOffset> = (Utc::now() + chrono::Duration::seconds(10))
            .with_timezone(&FixedOffset::east_opt(0).unwrap());
        let t = Timeout::AtDateTime(dt);
        assert_eq!(t.canonicalize(), t);
        let never = Timeout::Never;
        assert_eq!(never.canonicalize(), Timeout::Never);
    }

    #[test]
    #[cfg(feature = "email")]
    fn email_config_from_toml_console() {
        // The transport type lives under `[email.transport]`; a bare `type`
        // key directly under `[email]` would be ignored.
        let toml_content = r#"
[email.transport]
type = "console"
"#;
        let config = ProjectConfig::from_toml(toml_content).unwrap();
        assert_eq!(
            config.email.transport.transport_type,
            EmailTransportTypeConfig::Console
        );
    }

    #[test]
    #[cfg(feature = "email")]
    fn email_config_from_toml_smtp() {
        let toml_content = r#"
[email.transport]
type = "smtp"
url = "smtp://user:pass@hostname:587"
mechanism = "plain"
"#;
        let config = ProjectConfig::from_toml(toml_content).unwrap();
        let EmailTransportTypeConfig::Smtp { url, mechanism } =
            &config.email.transport.transport_type
        else {
            panic!(
                "expected an SMTP transport, got {:?}",
                config.email.transport.transport_type
            );
        };
        assert_eq!(url.as_str(), "smtp://user:pass@hostname:587");
        assert_eq!(*mechanism, Mechanism::Plain);
    }

    #[test]
    #[cfg(feature = "email")]
    fn email_config_builder_defaults() {
        let config = EmailConfig::builder().build();
        assert_eq!(
            config.transport.transport_type,
            EmailTransportTypeConfig::Console
        );
    }

    #[test]
    #[cfg(feature = "email")]
    fn email_url_from_str_and_string() {
        let s = "smtp://user:pass@hostname:587";
        let u1 = EmailUrl::from(s);
        let u2 = EmailUrl::from(s.to_string());
        assert_eq!(u1, u2);
        assert_eq!(u1.as_str(), s);
    }
}