use crate::parse::{parse_duration, parse_size};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
/// Storage backend selection.
///
/// Serde-tagged internally on the `backend` key: its value (`"filesystem"`,
/// `"postgres"` or `"amaters"`) chooses the variant, and that variant's own fields
/// sit alongside the `backend` key.
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(tag = "backend")]
pub enum StorageConfig {
    /// Filesystem storage rooted at `path` (`backend = "filesystem"`).
    #[serde(rename = "filesystem")]
    Filesystem { path: String },
    /// PostgreSQL storage (`backend = "postgres"`).
    // NOTE(review): the derived `Debug` prints `connection_string` verbatim; if it can
    // embed credentials, consider a redacting `Debug` impl.
    #[serde(rename = "postgres")]
    Postgres { connection_string: String },
    /// AmateRS storage (`backend = "amaters"`).
    #[serde(rename = "amaters")]
    AmateRS {
        /// Cluster endpoint addresses; their format is not validated here.
        endpoints: Vec<String>,
        /// Replication factor (presumably the number of copies kept — confirm with
        /// the backend); not range-checked here.
        replication_factor: usize,
    },
}
/// Configuration for one processor: its name, its state and its list of mailets.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ProcessorConfig {
    /// Processor name.
    pub name: String,
    /// Processor state identifier.
    // NOTE(review): how `state` relates to `name` is not visible here — confirm with
    // the processor runtime.
    pub state: String,
    /// Mailets in configuration order (presumably also the execution order — confirm).
    pub mailets: Vec<MailetConfig>,
}
/// One entry of a processor's mailet list: a matcher, a mailet and free-form parameters.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct MailetConfig {
    /// Matcher identifier; it is resolved outside this type.
    pub matcher: String,
    /// Mailet identifier; it is resolved outside this type.
    pub mailet: String,
    /// String key/value parameters; an empty map when omitted.
    #[serde(default)]
    pub params: HashMap<String, String>,
}
/// Authentication backend selection.
///
/// Serde-tagged internally on the `backend` key (`"file"`, `"ldap"`, `"sql"` or
/// `"oauth2"`). Each variant wraps its settings struct with `#[serde(flatten)]`, so the
/// settings appear at the same level as `backend` rather than nested under `config`.
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(tag = "backend")]
pub enum AuthConfig {
    /// File-based authentication (`backend = "file"`).
    #[serde(rename = "file")]
    File {
        #[serde(flatten)]
        config: FileAuthConfig,
    },
    /// LDAP authentication (`backend = "ldap"`).
    #[serde(rename = "ldap")]
    Ldap {
        #[serde(flatten)]
        config: LdapAuthConfig,
    },
    /// SQL-backed authentication (`backend = "sql"`).
    #[serde(rename = "sql")]
    Sql {
        #[serde(flatten)]
        config: SqlAuthConfig,
    },
    /// OAuth2 authentication (`backend = "oauth2"`).
    #[serde(rename = "oauth2")]
    OAuth2 {
        #[serde(flatten)]
        config: OAuth2AuthConfig,
    },
}
/// Settings for the `file` authentication backend.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct FileAuthConfig {
    /// Path to the backing file (presumably user credentials — confirm the format
    /// with the backend).
    pub path: String,
    /// Hash algorithm name; `"bcrypt"` when omitted (see `default_hash_algorithm`).
    #[serde(default = "default_hash_algorithm")]
    pub hash_algorithm: String,
}
/// Serde default for `FileAuthConfig::hash_algorithm`: the string `"bcrypt"`.
fn default_hash_algorithm() -> String {
    String::from("bcrypt")
}
/// Settings for the `ldap` authentication backend.
// NOTE(review): the derived `Debug` prints `bind_password` in clear text; consider a
// redacting `Debug` impl if this struct (or a config containing it) is ever logged.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct LdapAuthConfig {
    /// LDAP server URL.
    pub url: String,
    /// Base distinguished name (presumably the search base — confirm).
    pub base_dn: String,
    /// Distinguished name used to bind to the server.
    pub bind_dn: String,
    /// Password for `bind_dn`, held as a plain `String`.
    pub bind_password: String,
    /// Filter used to locate users; its placeholder syntax is defined by the backend.
    pub user_filter: String,
}
/// Settings for the `sql` authentication backend.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct SqlAuthConfig {
    /// Database connection string.
    // NOTE(review): may embed credentials, and the derived `Debug` prints it verbatim.
    pub connection_string: String,
    /// Query used by the backend; its expected parameters and columns are defined by
    /// the backend implementation, not here.
    pub query: String,
}
/// Settings for the `oauth2` authentication backend.
// NOTE(review): the derived `Debug` prints `client_secret` in clear text; consider a
// redacting `Debug` impl if this struct is ever logged.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct OAuth2AuthConfig {
    /// OAuth2 client identifier.
    pub client_id: String,
    /// OAuth2 client secret, held as a plain `String`.
    pub client_secret: String,
    /// Token endpoint URL.
    pub token_url: String,
    /// Authorization endpoint URL.
    pub authorization_url: String,
}
/// Logging settings.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct LoggingConfig {
    /// Log level; `validate_level` accepts `trace`, `debug`, `info`, `warn` or `error`.
    pub level: String,
    /// Output format; `validate_format` accepts `json` or `text`.
    pub format: String,
    /// Output destination; no validator for it is visible here.
    pub output: String,
    /// Optional log-file settings; `None` when omitted.
    #[serde(default)]
    pub file: Option<LogFileConfig>,
}
impl LoggingConfig {
    /// Checks that `level` is one of `trace`, `debug`, `info`, `warn` or `error`.
    ///
    /// The comparison is case-sensitive.
    ///
    /// # Errors
    ///
    /// Returns `Invalid log level: <level>` for any other value.
    pub fn validate_level(&self) -> anyhow::Result<()> {
        let known = matches!(
            self.level.as_str(),
            "trace" | "debug" | "info" | "warn" | "error"
        );
        if known {
            Ok(())
        } else {
            Err(anyhow::anyhow!("Invalid log level: {}", self.level))
        }
    }

    /// Checks that `format` is either `json` or `text` (case-sensitive).
    ///
    /// # Errors
    ///
    /// Returns `Invalid log format: <format>` for any other value.
    pub fn validate_format(&self) -> anyhow::Result<()> {
        const FORMATS: [&str; 2] = ["json", "text"];
        anyhow::ensure!(
            FORMATS.contains(&self.format.as_str()),
            "Invalid log format: {}",
            self.format
        );
        Ok(())
    }
}
/// Settings for writing logs to a file.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct LogFileConfig {
    /// Log file path.
    pub path: String,
    /// Human-readable size limit, converted with `parse_size` by `max_size_bytes`.
    pub max_size: String,
    /// Number of backups to keep (presumably rotated files — confirm with the logger).
    pub max_backups: u32,
    /// Whether backups are compressed (semantics owned by the logger — confirm).
    pub compress: bool,
}
impl LogFileConfig {
    /// Converts the human-readable `max_size` setting into a byte count via `parse_size`.
    ///
    /// # Errors
    ///
    /// Propagates the error from `parse_size` when `max_size` cannot be parsed.
    pub fn max_size_bytes(&self) -> anyhow::Result<usize> {
        let spec = &self.max_size;
        parse_size(spec)
    }
}
/// Queue and retry-backoff settings.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct QueueConfig {
    /// Initial delay as a duration string, parsed by `parse_duration`
    /// (see `initial_delay_seconds`).
    pub initial_delay: String,
    /// Maximum delay as a duration string, parsed by `parse_duration`
    /// (see `max_delay_seconds`).
    pub max_delay: String,
    /// Factor applied to the delay between attempts (presumably multiplicative
    /// backoff — confirm); checked by `validate_backoff_multiplier`.
    pub backoff_multiplier: f64,
    /// Maximum number of attempts; not range-checked here.
    pub max_attempts: u32,
    /// Number of worker threads; must be non-zero (`validate_worker_threads`).
    pub worker_threads: usize,
    /// Batch size; not range-checked here.
    pub batch_size: usize,
}
impl QueueConfig {
    /// Parses `initial_delay` with `parse_duration`.
    ///
    /// # Errors
    ///
    /// Propagates the error from `parse_duration` when the string cannot be parsed.
    pub fn initial_delay_seconds(&self) -> anyhow::Result<u64> {
        parse_duration(&self.initial_delay)
    }

    /// Parses `max_delay` with `parse_duration`.
    ///
    /// # Errors
    ///
    /// Propagates the error from `parse_duration` when the string cannot be parsed.
    pub fn max_delay_seconds(&self) -> anyhow::Result<u64> {
        parse_duration(&self.max_delay)
    }

    /// Checks that `backoff_multiplier` is a finite number greater than zero.
    ///
    /// `NaN` compares false against everything, so a bare `<= 0.0` test lets it
    /// through. Formats such as TOML and YAML can spell `nan` and `inf` directly, so
    /// non-finite values are rejected explicitly before the sign check.
    ///
    /// # Errors
    ///
    /// - `backoff_multiplier must be a finite number: <value>` for `NaN` or `±inf`.
    /// - `backoff_multiplier must be positive` for zero or negative values.
    pub fn validate_backoff_multiplier(&self) -> anyhow::Result<()> {
        if !self.backoff_multiplier.is_finite() {
            return Err(anyhow::anyhow!(
                "backoff_multiplier must be a finite number: {}",
                self.backoff_multiplier
            ));
        }
        if self.backoff_multiplier <= 0.0 {
            return Err(anyhow::anyhow!("backoff_multiplier must be positive"));
        }
        Ok(())
    }

    /// Checks that at least one worker thread is configured.
    ///
    /// # Errors
    ///
    /// Returns `worker_threads must be greater than 0` when `worker_threads` is zero.
    pub fn validate_worker_threads(&self) -> anyhow::Result<()> {
        if self.worker_threads == 0 {
            return Err(anyhow::anyhow!("worker_threads must be greater than 0"));
        }
        Ok(())
    }
}
/// Network- and recipient-level security settings.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct SecurityConfig {
    /// Networks expected in CIDR notation (checked by `validate_relay_networks`);
    /// presumably the clients allowed to relay — confirm.
    pub relay_networks: Vec<String>,
    /// Addresses to block (checked by `validate_blocked_ips`).
    pub blocked_ips: Vec<String>,
    /// Flag for recipient-existence checks; enforcement lives outside this type.
    pub check_recipient_exists: bool,
    /// Flag for rejecting unknown recipients; enforcement lives outside this type.
    pub reject_unknown_recipients: bool,
}
impl SecurityConfig {
    /// Checks that every entry in `relay_networks` is a well-formed CIDR block, such as
    /// `192.168.0.0/16` or `2001:db8::/32`.
    ///
    /// Each entry is parsed as an IPv4/IPv6 address, a `/`, and a decimal prefix length
    /// no larger than the address width (32 or 128). Merely looking for a `/` would
    /// accept values like `"foo/bar"` or `"10.0.0.0/99"`. Host bits below the prefix
    /// are not required to be zero.
    ///
    /// # Errors
    ///
    /// Returns `Invalid CIDR notation: <entry>` for the first entry that fails.
    pub fn validate_relay_networks(&self) -> anyhow::Result<()> {
        for network in &self.relay_networks {
            if !Self::is_valid_cidr(network) {
                return Err(anyhow::anyhow!("Invalid CIDR notation: {}", network));
            }
        }
        Ok(())
    }

    /// Checks that every entry in `blocked_ips` is an IP address or a CIDR block.
    ///
    /// Plain IPv4/IPv6 addresses are accepted, as are CIDR ranges like `10.0.0.0/8`, so
    /// range entries remain valid. Hostnames and other text that merely contains a `.`
    /// or `:` are rejected.
    ///
    /// # Errors
    ///
    /// Returns `Invalid IP address: <entry>` for the first entry that fails.
    pub fn validate_blocked_ips(&self) -> anyhow::Result<()> {
        for ip in &self.blocked_ips {
            let is_address = ip.parse::<std::net::IpAddr>().is_ok();
            if !is_address && !Self::is_valid_cidr(ip) {
                return Err(anyhow::anyhow!("Invalid IP address: {}", ip));
            }
        }
        Ok(())
    }

    /// Returns `true` when `value` is `<ip>/<prefix>`. The address must parse and the
    /// prefix must be in range for that address family.
    fn is_valid_cidr(value: &str) -> bool {
        match value.split_once('/') {
            Some((addr, prefix)) => {
                // Require plain decimal digits so forms like "+24" are not accepted.
                let digits_only =
                    !prefix.is_empty() && prefix.bytes().all(|b| b.is_ascii_digit());
                match (
                    digits_only,
                    addr.parse::<std::net::IpAddr>(),
                    prefix.parse::<u8>(),
                ) {
                    (true, Ok(ip), Ok(len)) => {
                        let max_len = if ip.is_ipv4() { 32 } else { 128 };
                        len <= max_len
                    }
                    _ => false,
                }
            }
            None => false,
        }
    }
}
/// Locally handled domains and address aliases.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct DomainsConfig {
    /// Local domain names; `validate_local_domains` requires each to be non-empty and
    /// to contain a `.`.
    pub local_domains: Vec<String>,
    /// Alias map from source address to destination address. `validate_aliases`
    /// requires both to contain `@`. The map is empty when omitted.
    #[serde(default)]
    pub aliases: HashMap<String, String>,
}
impl DomainsConfig {
    /// Checks every local domain in order.
    ///
    /// # Errors
    ///
    /// For the first offending entry, returns:
    /// - `Domain name cannot be empty` when the entry is an empty string;
    /// - `Invalid domain name: <domain>` when it contains no `.`.
    pub fn validate_local_domains(&self) -> anyhow::Result<()> {
        self.local_domains.iter().try_for_each(|name| {
            // The emptiness check must come first: "" also lacks a '.'.
            anyhow::ensure!(!name.is_empty(), "Domain name cannot be empty");
            anyhow::ensure!(name.contains('.'), "Invalid domain name: {}", name);
            Ok(())
        })
    }

    /// Checks that both sides of every alias contain an `@`.
    ///
    /// Entries are visited in the map's iteration order, so when several aliases are
    /// invalid, which one is reported is unspecified.
    ///
    /// # Errors
    ///
    /// Returns `Invalid alias source: <from>` or `Invalid alias destination: <to>`.
    pub fn validate_aliases(&self) -> anyhow::Result<()> {
        for (source, target) in self.aliases.iter() {
            anyhow::ensure!(source.contains('@'), "Invalid alias source: {}", source);
            anyhow::ensure!(
                target.contains('@'),
                "Invalid alias destination: {}",
                target
            );
        }
        Ok(())
    }
}
/// Metrics endpoint settings.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct MetricsConfig {
    /// Whether metrics are enabled.
    pub enabled: bool,
    /// Listen address; `validate_bind_address` only checks that it contains a `:`.
    pub bind_address: String,
    /// Path for the metrics endpoint; must start with `/` (`validate_path`).
    pub path: String,
    /// Optional basic-auth credentials: `None` when omitted, and left out of the
    /// serialized output when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub basic_auth: Option<MetricsBasicAuthConfig>,
}
/// Basic-auth credentials for the metrics endpoint (`MetricsConfig::basic_auth`).
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct MetricsBasicAuthConfig {
    /// Expected username.
    pub username: String,
    /// Hash of the expected password; the hash scheme is defined by the verifying
    /// code, not here — confirm.
    pub password_hash: String,
}
impl MetricsConfig {
    /// Performs a shallow format check that `bind_address` contains a `:` separator.
    ///
    /// The address is neither parsed nor resolved, so `host:port` forms with hostnames
    /// are accepted.
    ///
    /// # Errors
    ///
    /// Returns `Invalid bind address format: <address>` when no `:` is present.
    pub fn validate_bind_address(&self) -> anyhow::Result<()> {
        match self.bind_address.find(':') {
            Some(_) => Ok(()),
            None => Err(anyhow::anyhow!(
                "Invalid bind address format: {}",
                self.bind_address
            )),
        }
    }

    /// Checks that `path` is absolute, i.e. begins with `/`.
    ///
    /// # Errors
    ///
    /// Returns `Metrics path must start with '/': <path>` otherwise.
    pub fn validate_path(&self) -> anyhow::Result<()> {
        let absolute = matches!(self.path.as_bytes().first(), Some(b'/'));
        anyhow::ensure!(absolute, "Metrics path must start with '/': {}", self.path);
        Ok(())
    }
}
/// OTLP tracing settings.
///
/// When deserializing, every field except `sample_ratio` is required. An omitted
/// `sample_ratio` becomes `1.0`, the same value `TracingConfig::default()` uses.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct TracingConfig {
    /// Whether tracing is enabled.
    pub enabled: bool,
    /// Collector endpoint URL; must start with `http://` or `https://`
    /// (`validate_endpoint`).
    pub endpoint: String,
    /// OTLP transport protocol.
    pub protocol: OtlpProtocol,
    /// Service name; must not be blank (`validate_service_name`).
    pub service_name: String,
    /// Sampling ratio in `[0.0, 1.0]` (`validate_sample_ratio`).
    ///
    /// Defaults to `1.0` when omitted. A bare `#[serde(default)]` would produce
    /// `f64::default()` (`0.0`), which would presumably sample nothing. It would also
    /// disagree with `TracingConfig::default()`.
    #[serde(default = "default_tracing_sample_ratio")]
    pub sample_ratio: f64,
}

/// Serde default for `TracingConfig::sample_ratio`; matches `TracingConfig::default()`.
fn default_tracing_sample_ratio() -> f64 {
    1.0
}
/// OTLP transport protocol used by `TracingConfig`.
///
/// Variant names are (de)serialized in lowercase: `"grpc"` and `"http"`.
#[derive(Debug, Clone, Deserialize, Serialize, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum OtlpProtocol {
    /// gRPC transport (`"grpc"`).
    Grpc,
    /// HTTP transport (`"http"`).
    Http,
}
impl Default for TracingConfig {
fn default() -> Self {
Self {
enabled: false,
endpoint: "http://localhost:4317".to_string(),
protocol: OtlpProtocol::Grpc,
service_name: "rusmes".to_string(),
sample_ratio: 1.0,
}
}
}
impl TracingConfig {
    /// Requires `endpoint` to begin with an `http://` or `https://` scheme prefix.
    ///
    /// # Errors
    ///
    /// Returns `Endpoint must start with http:// or https://: <endpoint>` otherwise.
    pub fn validate_endpoint(&self) -> anyhow::Result<()> {
        let has_scheme = ["http://", "https://"]
            .iter()
            .any(|&scheme| self.endpoint.starts_with(scheme));
        if has_scheme {
            Ok(())
        } else {
            Err(anyhow::anyhow!(
                "Endpoint must start with http:// or https://: {}",
                self.endpoint
            ))
        }
    }

    /// Requires `sample_ratio` to lie in the closed interval `[0.0, 1.0]`.
    ///
    /// `NaN` is never contained in the range, so it is rejected as well.
    ///
    /// # Errors
    ///
    /// Returns `Sample ratio must be between 0.0 and 1.0: <ratio>` otherwise.
    pub fn validate_sample_ratio(&self) -> anyhow::Result<()> {
        let ratio = self.sample_ratio;
        anyhow::ensure!(
            (0.0..=1.0).contains(&ratio),
            "Sample ratio must be between 0.0 and 1.0: {}",
            ratio
        );
        Ok(())
    }

    /// Requires `service_name` to contain at least one non-whitespace character.
    ///
    /// # Errors
    ///
    /// Returns `Service name cannot be empty` for empty or whitespace-only names.
    pub fn validate_service_name(&self) -> anyhow::Result<()> {
        match self.service_name.trim() {
            "" => Err(anyhow::anyhow!("Service name cannot be empty")),
            _ => Ok(()),
        }
    }
}