#![allow(missing_copy_implementations)]
use std::time::Duration;
use derive_builder::Builder;
use derive_more::with_trait::{Debug, From};
use serde::{Deserialize, Serialize};
use subtle::ConstantTimeEq;
/// Top-level configuration of a project.
///
/// Because of `#[serde(default)]`, any field missing from the deserialized input is taken
/// from [`ProjectConfig::default`], which delegates to [`ProjectConfigBuilder::build`].
/// The derived builder's own `build` is skipped (`build_fn(skip)`); the hand-written
/// `ProjectConfigBuilder::build` is infallible and fills in defaults instead.
#[derive(Debug, Clone, PartialEq, Eq, Builder, Serialize, Deserialize)]
#[builder(build_fn(skip, error = std::convert::Infallible))]
#[serde(default)]
pub struct ProjectConfig {
    /// Whether the project runs in debug mode.
    ///
    /// Defaults to `true` when compiled with debug assertions and `false` otherwise.
    pub debug: bool,
    /// Whether a panic hook should be registered. Defaults to `true`.
    pub register_panic_hook: bool,
    /// The primary secret key. Defaults to an empty key.
    pub secret_key: SecretKey,
    /// Additional secret keys. Defaults to an empty list.
    ///
    /// NOTE(review): presumably accepted alongside `secret_key` to support key rotation —
    /// confirm against the code that consumes this field.
    pub fallback_secret_keys: Vec<SecretKey>,
    /// The authentication backend to use. Defaults to [`AuthBackendConfig::None`].
    pub auth_backend: AuthBackendConfig,
    /// Database settings; this field only exists when the `db` feature is enabled.
    #[cfg(feature = "db")]
    pub database: DatabaseConfig,
    /// Static file settings.
    pub static_files: StaticFilesConfig,
    /// Per-middleware settings.
    pub middlewares: MiddlewareConfig,
}
/// Default value of [`ProjectConfig::debug`]: `true` exactly when the crate was compiled
/// with debug assertions enabled.
///
/// `cfg!` is resolved at compile time, so this is usable in `const` contexts.
const fn default_debug() -> bool {
    const IS_DEBUG_BUILD: bool = cfg!(debug_assertions);
    IS_DEBUG_BUILD
}
impl Default for ProjectConfig {
fn default() -> Self {
ProjectConfig::builder().build()
}
}
impl ProjectConfig {
    /// Returns a fresh [`ProjectConfigBuilder`] with no fields set.
    #[must_use]
    pub fn builder() -> ProjectConfigBuilder {
        Default::default()
    }
    /// Returns a configuration intended for local development.
    ///
    /// Turns on `debug` and `register_panic_hook`. With the `db` feature enabled, it also
    /// sets the database URL to `sqlite::memory:`. All other fields keep their builder
    /// defaults.
    #[must_use]
    pub fn dev_default() -> ProjectConfig {
        let mut dev_builder = Self::builder();
        dev_builder.debug(true);
        dev_builder.register_panic_hook(true);
        #[cfg(feature = "db")]
        dev_builder.database(DatabaseConfig::builder().url("sqlite::memory:").build());
        dev_builder.build()
    }
    /// Parses a configuration from TOML text.
    ///
    /// Fields absent from `toml_content` fall back to [`ProjectConfig::default`]
    /// (the struct is `#[serde(default)]`).
    ///
    /// # Errors
    ///
    /// Returns an error if `toml_content` cannot be deserialized into a `ProjectConfig`,
    /// e.g. because it is not valid TOML or a value has the wrong type.
    pub fn from_toml(toml_content: &str) -> crate::Result<ProjectConfig> {
        Ok(toml::from_str(toml_content)?)
    }
}
impl ProjectConfigBuilder {
#[must_use]
pub fn build(&self) -> ProjectConfig {
let debug = self.debug.unwrap_or(default_debug());
ProjectConfig {
debug,
register_panic_hook: self.register_panic_hook.unwrap_or(true),
secret_key: self.secret_key.clone().unwrap_or_default(),
fallback_secret_keys: self.fallback_secret_keys.clone().unwrap_or_default(),
auth_backend: self.auth_backend.unwrap_or_default(),
#[cfg(feature = "db")]
database: self.database.clone().unwrap_or_default(),
static_files: self.static_files.clone().unwrap_or_default(),
middlewares: self.middlewares.clone().unwrap_or_default(),
}
}
}
/// Selects the authentication backend.
///
/// Serialized as an internally tagged value keyed by `type`, with snake_case variant
/// names, e.g. `auth_backend = { type = "none" }`.
#[derive(Debug, Default, Copy, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[non_exhaustive]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum AuthBackendConfig {
    /// The default variant, serialized as `{ type = "none" }`.
    #[default]
    None,
    /// Serialized as `{ type = "database" }`; only available with the `db` feature.
    ///
    /// NOTE(review): presumably selects a database-backed backend — confirm.
    #[cfg(feature = "db")]
    Database,
}
/// Database settings; only available with the `db` feature.
#[cfg(feature = "db")]
#[derive(Debug, Default, Clone, PartialEq, Eq, Builder, Serialize, Deserialize)]
#[builder(build_fn(skip, error = std::convert::Infallible))]
#[serde(default)]
pub struct DatabaseConfig {
    /// The database connection URL, if any. `DatabaseConfig::default()` leaves it `None`.
    ///
    /// The builder setter accepts anything convertible into [`DatabaseUrl`] (such as
    /// `&str` or `String`) and wraps it in `Some`. Converting a string that is not a valid
    /// URL panics.
    #[builder(setter(into, strip_option), default)]
    pub url: Option<DatabaseUrl>,
}
#[cfg(feature = "db")]
impl DatabaseConfigBuilder {
    /// Builds a [`DatabaseConfig`] from the values set on this builder.
    ///
    /// `url` is optional. It is declared as `Option<DatabaseUrl>` with `#[builder(default)]`,
    /// and [`DatabaseConfig::default`] leaves it `None`. An unset URL therefore also
    /// produces `url: None` here instead of panicking, matching the other config builders,
    /// which never panic.
    #[must_use]
    pub fn build(&self) -> DatabaseConfig {
        DatabaseConfig {
            // The builder stores `Option<Option<DatabaseUrl>>`. The outer `None` means the
            // setter was never called, which is the same as "no URL configured".
            url: self.url.clone().flatten(),
        }
    }
}
#[cfg(feature = "db")]
impl DatabaseConfig {
    /// Returns a [`DatabaseConfigBuilder`] with no fields set.
    #[must_use]
    pub fn builder() -> DatabaseConfigBuilder {
        Default::default()
    }
}
/// Static file settings.
///
/// `Default` delegates to [`StaticFilesConfigBuilder::build`], so with `#[serde(default)]`
/// any field missing from the input takes the builder default.
#[derive(Debug, Clone, PartialEq, Eq, Builder, Serialize, Deserialize)]
#[builder(build_fn(skip, error = std::convert::Infallible))]
#[serde(default)]
pub struct StaticFilesConfig {
    /// The static files URL (presumably the prefix files are served under — confirm).
    /// Defaults to `"/static/"`.
    #[builder(setter(into))]
    pub url: String,
    /// Path rewrite mode. Defaults to [`StaticFilesPathRewriteMode::None`].
    pub rewrite: StaticFilesPathRewriteMode,
    /// Optional cache timeout, written as a human-readable duration such as `"1h"`.
    /// Defaults to `None`.
    #[serde(with = "crate::serializers::humantime")]
    #[builder(setter(strip_option), default)]
    pub cache_timeout: Option<Duration>,
}
/// Path rewrite mode for static files, serialized in snake_case
/// (`"none"`, `"query_param"`).
#[derive(Debug, Default, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[non_exhaustive]
#[serde(rename_all = "snake_case")]
pub enum StaticFilesPathRewriteMode {
    /// No rewriting; the default.
    #[default]
    None,
    /// NOTE(review): presumably rewrites paths by adding a query parameter (e.g. for cache
    /// busting) — confirm against the static files handler.
    QueryParam,
}
impl StaticFilesConfigBuilder {
    /// Builds a [`StaticFilesConfig`], filling every unset field with its default.
    ///
    /// Defaults: `url` is `"/static/"`, `rewrite` is [`StaticFilesPathRewriteMode::None`]
    /// and `cache_timeout` is `None`.
    #[must_use]
    pub fn build(&self) -> StaticFilesConfig {
        StaticFilesConfig {
            // Borrow first and allocate exactly once. The previous form cloned the
            // configured URL *and* eagerly allocated the fallback string on every call.
            url: self.url.as_deref().unwrap_or("/static/").to_owned(),
            rewrite: self.rewrite.clone().unwrap_or_default(),
            cache_timeout: self.cache_timeout.unwrap_or_default(),
        }
    }
}
impl Default for StaticFilesConfig {
fn default() -> Self {
StaticFilesConfig::builder().build()
}
}
impl StaticFilesConfig {
#[must_use]
pub fn builder() -> StaticFilesConfigBuilder {
StaticFilesConfigBuilder::default()
}
}
/// Settings for the individual middlewares.
///
/// With `#[serde(default)]`, anything missing from the input is taken from the derived
/// `MiddlewareConfig::default()`, i.e. each field's own `Default`.
#[derive(Debug, Default, Clone, PartialEq, Eq, Builder, Serialize, Deserialize)]
#[builder(build_fn(skip, error = std::convert::Infallible))]
#[serde(default)]
pub struct MiddlewareConfig {
    /// Live-reload middleware settings.
    pub live_reload: LiveReloadMiddlewareConfig,
    /// Session middleware settings.
    pub session: SessionMiddlewareConfig,
}
impl MiddlewareConfig {
#[must_use]
pub fn builder() -> MiddlewareConfigBuilder {
MiddlewareConfigBuilder::default()
}
}
impl MiddlewareConfigBuilder {
    /// Builds a [`MiddlewareConfig`]; each unset field uses its type's `Default`.
    #[must_use]
    pub fn build(&self) -> MiddlewareConfig {
        let live_reload = self.live_reload.as_ref().cloned().unwrap_or_default();
        let session = self.session.as_ref().cloned().unwrap_or_default();
        MiddlewareConfig {
            live_reload,
            session,
        }
    }
}
/// Live-reload middleware settings.
#[derive(Debug, Default, Clone, PartialEq, Eq, Builder, Serialize, Deserialize)]
#[builder(build_fn(skip, error = std::convert::Infallible))]
#[serde(default)]
pub struct LiveReloadMiddlewareConfig {
    /// Whether the live-reload middleware is enabled. Defaults to `false`.
    pub enabled: bool,
}
impl LiveReloadMiddlewareConfig {
#[must_use]
pub fn builder() -> LiveReloadMiddlewareConfigBuilder {
LiveReloadMiddlewareConfigBuilder::default()
}
}
impl LiveReloadMiddlewareConfigBuilder {
    /// Builds a [`LiveReloadMiddlewareConfig`]; `enabled` defaults to `false`.
    #[must_use]
    pub fn build(&self) -> LiveReloadMiddlewareConfig {
        let enabled = self.enabled.unwrap_or(false);
        LiveReloadMiddlewareConfig { enabled }
    }
}
#[derive(Debug, Default, Clone, PartialEq, Eq, Builder, Serialize, Deserialize)]
#[builder(build_fn(skip, error = std::convert::Infallible))]
#[serde(default)]
pub struct SessionMiddlewareConfig {
pub secure: bool,
}
impl SessionMiddlewareConfig {
#[must_use]
pub fn builder() -> SessionMiddlewareConfigBuilder {
SessionMiddlewareConfigBuilder::default()
}
}
impl SessionMiddlewareConfigBuilder {
    /// Builds a [`SessionMiddlewareConfig`]; `secure` defaults to `true` when unset.
    #[must_use]
    pub fn build(&self) -> SessionMiddlewareConfig {
        let secure = self.secure.unwrap_or(true);
        SessionMiddlewareConfig { secure }
    }
}
#[repr(transparent)]
#[derive(Clone, Serialize, Deserialize)]
#[serde(from = "String")]
pub struct SecretKey(Box<[u8]>);
impl SecretKey {
#[must_use]
pub fn new(hash: &[u8]) -> Self {
Self(Box::from(hash))
}
#[must_use]
pub fn as_bytes(&self) -> &[u8] {
&self.0
}
#[must_use]
pub fn into_bytes(self) -> Box<[u8]> {
self.0
}
}
impl From<&[u8]> for SecretKey {
fn from(value: &[u8]) -> Self {
Self::new(value)
}
}
impl From<String> for SecretKey {
fn from(value: String) -> Self {
Self::new(value.as_bytes())
}
}
impl From<&str> for SecretKey {
fn from(value: &str) -> Self {
Self::new(value.as_bytes())
}
}
impl PartialEq for SecretKey {
    /// Compares the key bytes using `subtle::ConstantTimeEq` rather than a short-circuiting
    /// byte comparison.
    fn eq(&self, other: &Self) -> bool {
        let same = self.as_bytes().ct_eq(other.as_bytes());
        bool::from(same)
    }
}
impl Eq for SecretKey {}
impl Debug for SecretKey {
    /// Never prints key material; always writes the same fixed placeholder.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("SecretKey(\"**********\")")
    }
}
impl Default for SecretKey {
fn default() -> Self {
Self::new(&[])
}
}
/// A database connection URL.
///
/// Serialized and deserialized transparently as the wrapped [`url::Url`]. Its `Debug`
/// output masks the username and password.
#[derive(Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
#[serde(transparent)]
#[cfg(feature = "db")]
pub struct DatabaseUrl(url::Url);
#[cfg(feature = "db")]
impl DatabaseUrl {
    /// Returns the full URL as a string slice.
    ///
    /// Unlike the `Debug` output, nothing is masked here, so credentials are included.
    #[must_use]
    pub fn as_str(&self) -> &str {
        let Self(inner) = self;
        inner.as_str()
    }
}
#[cfg(feature = "db")]
impl From<String> for DatabaseUrl {
    /// Parses `url` as a database URL.
    ///
    /// # Panics
    ///
    /// Panics if `url` is not a valid URL.
    fn from(url: String) -> Self {
        Self::from(url.as_str())
    }
}
#[cfg(feature = "db")]
impl From<&str> for DatabaseUrl {
    /// Parses `url` as a database URL.
    ///
    /// # Panics
    ///
    /// Panics if `url` is not a valid URL.
    fn from(url: &str) -> Self {
        let parsed = url::Url::parse(url).expect("valid URL");
        Self(parsed)
    }
}
#[cfg(feature = "db")]
impl Debug for DatabaseUrl {
    /// Formats the URL with its username and password (when present) replaced by
    /// `********`, so credentials do not appear in debug output.
    ///
    /// NOTE(review): only the userinfo part is masked; secrets passed in the query string
    /// (e.g. `?password=...`) are printed as-is.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        const MASK: &str = "********";
        let mut redacted = self.0.clone();
        let has_username = !redacted.username().is_empty();
        if has_username {
            redacted
                .set_username(MASK)
                .expect("set_username should succeed if username is present");
        }
        if redacted.password().is_some() {
            redacted
                .set_password(Some(MASK))
                .expect("set_password should succeed if password is present");
        }
        f.debug_tuple("DatabaseUrl")
            .field(&redacted.as_str())
            .finish()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    #[test]
    fn from_toml_valid() {
        let toml_content = r#"
debug = true
register_panic_hook = true
secret_key = "123abc"
fallback_secret_keys = ["456def", "789ghi"]
auth_backend = { type = "none" }
[static_files]
url = "/assets/"
rewrite = "none"
cache_timeout = "1h"
[middlewares]
live_reload.enabled = true
[middlewares.session]
secure = false
"#;
        let config = ProjectConfig::from_toml(toml_content).unwrap();
        assert!(config.debug);
        assert!(config.register_panic_hook);
        assert_eq!(config.secret_key.as_bytes(), b"123abc");
        assert_eq!(config.fallback_secret_keys.len(), 2);
        assert_eq!(config.fallback_secret_keys[0].as_bytes(), b"456def");
        assert_eq!(config.fallback_secret_keys[1].as_bytes(), b"789ghi");
        assert_eq!(config.auth_backend, AuthBackendConfig::None);
        assert_eq!(config.static_files.url, "/assets/");
        assert_eq!(
            config.static_files.rewrite,
            StaticFilesPathRewriteMode::None
        );
        assert_eq!(
            config.static_files.cache_timeout,
            Some(Duration::from_secs(3600))
        );
        assert!(config.middlewares.live_reload.enabled);
        assert!(!config.middlewares.session.secure);
    }
    #[test]
    fn from_toml_invalid() {
        let toml_content = r"
debug = true
secret_key = 123abc
";
        let result = ProjectConfig::from_toml(toml_content);
        assert!(result.is_err());
    }
    #[test]
    fn from_toml_missing_fields() {
        let toml_content = r#"
secret_key = "123abc"
[static_files]
rewrite = "query_param"
"#;
        let config = ProjectConfig::from_toml(toml_content).unwrap();
        assert_eq!(config.debug, cfg!(debug_assertions));
        assert_eq!(config.secret_key.as_bytes(), b"123abc");
        assert_eq!(config.static_files.url, "/static/");
        assert_eq!(
            config.static_files.rewrite,
            StaticFilesPathRewriteMode::QueryParam
        );
        // Fields absent from the input must fall back to their documented defaults.
        assert!(config.register_panic_hook);
        assert!(config.fallback_secret_keys.is_empty());
        assert_eq!(config.auth_backend, AuthBackendConfig::None);
        assert_eq!(config.static_files.cache_timeout, None);
        assert!(!config.middlewares.live_reload.enabled);
    }
    #[test]
    fn secret_key_debug_is_redacted() {
        let key = SecretKey::from("super-secret");
        let rendered = format!("{key:?}");
        assert_eq!(rendered, "SecretKey(\"**********\")");
        assert!(!rendered.contains("super-secret"));
    }
    #[test]
    fn secret_key_equality() {
        assert_eq!(SecretKey::from("abc"), SecretKey::from(String::from("abc")));
        assert_eq!(SecretKey::from("abc"), SecretKey::new(b"abc"));
        assert_ne!(SecretKey::from("abc"), SecretKey::from("abd"));
        assert_ne!(SecretKey::from("abc"), SecretKey::from("abcd"));
        assert_eq!(SecretKey::default().as_bytes(), b"");
    }
    #[cfg(feature = "db")]
    #[test]
    fn database_url_debug_masks_credentials() {
        let url = DatabaseUrl::from("postgres://user:pass@localhost/db");
        assert_eq!(
            format!("{url:?}"),
            "DatabaseUrl(\"postgres://********:********@localhost/db\")"
        );
        // The plain accessor is intentionally unmasked.
        assert_eq!(url.as_str(), "postgres://user:pass@localhost/db");
    }
}