use std::collections::HashSet;
use std::net::{Ipv4Addr, SocketAddr};
use std::path::{Path, PathBuf};
use serde::Deserialize;
use crate::error::{ConfigError, Result, ValidationError};
/// Content format of a blocklist source, selected by the `format` key of a
/// `[[blocklist_sources]]` entry.
///
/// Variant names are matched in lowercase (`"domains"`, `"hosts"`,
/// `"adblock"`) because of `rename_all = "lowercase"`. When `format` is
/// omitted, [`BlocklistFormat::Domains`] is used (the `#[default]` variant).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum BlocklistFormat {
/// `format = "domains"`; the default format.
#[default]
Domains,
/// `format = "hosts"`. NOTE(review): presumably hosts-file style lines —
/// confirm against the blocklist parser.
Hosts,
/// `format = "adblock"`. NOTE(review): presumably adblock filter syntax —
/// confirm against the blocklist parser.
Adblock,
}
/// Where a blocklist source is loaded from.
///
/// Internally tagged by a `type` key, e.g.
/// `source = { type = "file", path = "/etc/list.txt" }` or
/// `source = { type = "remote", url = "https://example.com/list.txt" }`.
///
/// Unknown keys inside the `source` table are rejected. This matches the
/// `deny_unknown_fields` policy applied to every other config table, so that a
/// stray or misspelled key (for example a `url` on a `file` source) is reported
/// instead of being silently ignored.
#[derive(Debug, Clone, PartialEq, Eq, Deserialize)]
#[serde(tag = "type", rename_all = "lowercase", deny_unknown_fields)]
pub enum BlocklistSourceType {
    /// A blocklist read from a local file (`type = "file"`).
    File {
        /// Path to the blocklist file; config validation rejects an empty path.
        path: PathBuf,
    },
    /// A remote blocklist identified by a URL (`type = "remote"`).
    Remote {
        /// Source URL; config validation requires an `http://` or `https://`
        /// scheme.
        url: String,
    },
}
/// One entry of the `[[blocklist_sources]]` array in the config file.
///
/// Unknown keys are rejected (`deny_unknown_fields`).
#[derive(Debug, Clone, Deserialize)]
#[serde(deny_unknown_fields)]
pub struct BlocklistSourceConfig {
/// Identifier of the source. Config validation requires it to be non-empty
/// and unique across all `blocklist_sources` entries.
pub name: String,
/// Whether this source is enabled; defaults to `true` (`default_enabled`).
#[serde(default = "default_enabled")]
pub enabled: bool,
/// Origin of the list: a local file or a remote URL (tagged by `type`).
pub source: BlocklistSourceType,
/// Content format of the list; defaults to [`BlocklistFormat::Domains`].
#[serde(default)]
pub format: BlocklistFormat,
/// Optional refresh period, in hours per the field name. Config validation
/// only logs a warning (and still accepts the config) when this is set on a
/// file source. NOTE(review): how `None` is interpreted for remote sources
/// is decided by the loader — confirm there.
pub refresh_interval_hours: Option<u64>,
}
/// Serde default for `BlocklistSourceConfig::enabled`: a source listed in the
/// config is active unless it explicitly sets `enabled = false`.
const fn default_enabled() -> bool {
    const ENABLED_BY_DEFAULT: bool = true;
    ENABLED_BY_DEFAULT
}
/// Top-level application configuration, deserialized from TOML.
///
/// Unknown top-level keys are rejected (`deny_unknown_fields`). Use
/// `Config::load` / `Config::parse` to obtain a validated instance.
#[derive(Debug, Clone, Deserialize)]
#[serde(deny_unknown_fields)]
pub struct Config {
/// Optional network interface name (e.g. `"eth0"`). NOTE(review): how
/// `None` is handled is decided by the consumer of this field — confirm.
pub interface: Option<String>,
/// Upstream resolver address, written as an `IP:port` string such as
/// `"1.1.1.1:53"` and parsed by `deserialize_socket_addr`. Required.
#[serde(deserialize_with = "deserialize_socket_addr")]
pub upstream_resolver: SocketAddr,
/// Cache TTL in seconds; defaults to 300 and must be non-zero.
#[serde(default = "default_cache_ttl")]
pub cache_ttl_seconds: u64,
/// Inline blocklist patterns (e.g. `"example.com"`, `"*.ads.com"`).
/// Validation rejects empty patterns and a bare `"*."`. Defaults to empty.
#[serde(default)]
pub blocklist: Vec<String>,
/// Additional file/remote blocklist sources; defaults to empty.
#[serde(default)]
pub blocklist_sources: Vec<BlocklistSourceConfig>,
/// Optional override for the blocklist cache directory; when `None`,
/// `Config::blocklist_cache_dir()` falls back to
/// `crate::blocklist::remote::default_cache_dir`.
pub blocklist_cache_dir: Option<PathBuf>,
/// Buffer pool size; defaults to 64 and must be non-zero.
#[serde(default = "default_buffer_pool_size")]
pub buffer_pool_size: usize,
/// Channel capacity; defaults to 1000 and must be non-zero.
#[serde(default = "default_channel_capacity")]
pub channel_capacity: usize,
/// `[arp_spoof]` table; uses `ArpSpoofSettings::default()` when omitted.
#[serde(default)]
pub arp_spoof: ArpSpoofSettings,
/// `[metrics]` table; uses `MetricsConfig::default()` when omitted.
#[serde(default)]
pub metrics: MetricsConfig,
}
/// Settings from the `[arp_spoof]` table. Unknown keys are rejected.
#[derive(Debug, Clone, Deserialize)]
#[serde(deny_unknown_fields)]
pub struct ArpSpoofSettings {
/// Whether ARP spoofing is enabled; defaults to `false`.
#[serde(default)]
pub enabled: bool,
/// Optional gateway IPv4 address (e.g. `"192.168.1.1"`). NOTE(review): how
/// `None` is resolved (e.g. auto-detection) is up to the spoofer — confirm.
pub gateway_ip: Option<Ipv4Addr>,
/// Spoof interval in seconds (per the field name); defaults to 2. Config
/// validation rejects 0 even when `enabled` is `false`.
#[serde(default = "default_spoof_interval")]
pub spoof_interval_secs: u64,
/// Defaults to `true`. NOTE(review): the name suggests ARP state is
/// restored on shutdown — confirm in the spoofer implementation.
#[serde(default = "default_restore_on_shutdown")]
pub restore_on_shutdown: bool,
/// Defaults to `true`. NOTE(review): the name suggests intercepted traffic
/// is forwarded onward — confirm in the spoofer implementation.
#[serde(default = "default_forward_traffic")]
pub forward_traffic: bool,
}
impl Default for ArpSpoofSettings {
    /// Builds the same values serde would produce for an empty `[arp_spoof]`
    /// table. An omitted table and an empty one therefore behave identically:
    /// spoofing is disabled, there is no gateway override, and every other
    /// field comes from its `default_*` helper.
    fn default() -> Self {
        let spoof_interval_secs = default_spoof_interval();
        let restore_on_shutdown = default_restore_on_shutdown();
        let forward_traffic = default_forward_traffic();

        Self {
            enabled: false,
            gateway_ip: None,
            spoof_interval_secs,
            restore_on_shutdown,
            forward_traffic,
        }
    }
}
/// Settings from the `[metrics]` table. Unknown keys are rejected.
#[derive(Debug, Clone, Deserialize)]
#[serde(deny_unknown_fields)]
pub struct MetricsConfig {
/// Whether metrics are enabled; defaults to `false`.
#[serde(default)]
pub enabled: bool,
/// Listen address as an `IP:port` string; defaults to `0.0.0.0:9090`.
/// Parsed by serde's built-in `SocketAddr` support rather than
/// `deserialize_socket_addr`.
#[serde(default = "default_metrics_listen")]
pub listen: SocketAddr,
}
impl Default for MetricsConfig {
    /// Metrics start disabled. The listen address is taken from
    /// `default_metrics_listen()`, so an omitted `[metrics]` table matches an
    /// empty one.
    fn default() -> Self {
        let listen = default_metrics_listen();
        Self {
            listen,
            enabled: false,
        }
    }
}
/// Default DNS cache TTL, in seconds (five minutes).
const DEFAULT_CACHE_TTL_SECS: u64 = 5 * 60;
/// Default number of buffers in the buffer pool.
const DEFAULT_BUFFER_POOL_SIZE: usize = 64;
/// Default channel capacity.
const DEFAULT_CHANNEL_CAPACITY: usize = 1_000;
/// Default ARP spoof interval, in seconds.
const DEFAULT_SPOOF_INTERVAL_SECS: u64 = 2;
/// Default TCP port of the metrics listener.
const DEFAULT_METRICS_PORT: u16 = 9090;

/// Serde default for `Config::cache_ttl_seconds`.
const fn default_cache_ttl() -> u64 {
    DEFAULT_CACHE_TTL_SECS
}

/// Serde default for `Config::buffer_pool_size`.
const fn default_buffer_pool_size() -> usize {
    DEFAULT_BUFFER_POOL_SIZE
}

/// Serde default for `Config::channel_capacity`.
const fn default_channel_capacity() -> usize {
    DEFAULT_CHANNEL_CAPACITY
}

/// Serde default for `ArpSpoofSettings::spoof_interval_secs`.
const fn default_spoof_interval() -> u64 {
    DEFAULT_SPOOF_INTERVAL_SECS
}

/// Serde default for `ArpSpoofSettings::restore_on_shutdown` (on).
const fn default_restore_on_shutdown() -> bool {
    true
}

/// Serde default for `ArpSpoofSettings::forward_traffic` (on).
const fn default_forward_traffic() -> bool {
    true
}

/// Serde default for `MetricsConfig::listen`: all IPv4 interfaces
/// (`0.0.0.0`) on [`DEFAULT_METRICS_PORT`].
fn default_metrics_listen() -> SocketAddr {
    SocketAddr::from((Ipv4Addr::UNSPECIFIED, DEFAULT_METRICS_PORT))
}
fn deserialize_socket_addr<'de, D>(deserializer: D) -> std::result::Result<SocketAddr, D::Error>
where
D: serde::Deserializer<'de>,
{
let s = String::deserialize(deserializer)?;
s.parse().map_err(serde::de::Error::custom)
}
impl Config {
pub fn load(path: impl AsRef<Path>) -> Result<Self> {
let content = std::fs::read_to_string(path).map_err(ConfigError::ReadFile)?;
Self::parse(&content)
}
pub fn parse(content: &str) -> Result<Self> {
let config: Self = toml::from_str(content).map_err(ConfigError::Parse)?;
config.validate()?;
Ok(config)
}
#[must_use]
pub fn blocklist_cache_dir(&self) -> PathBuf {
self.blocklist_cache_dir
.clone()
.unwrap_or_else(crate::blocklist::remote::default_cache_dir)
}
fn validate(&self) -> Result<()> {
if self.cache_ttl_seconds == 0 {
return Err(ConfigError::from(ValidationError::ZeroCacheTtl).into());
}
if self.buffer_pool_size == 0 {
return Err(ConfigError::from(ValidationError::ZeroBufferPoolSize).into());
}
if self.channel_capacity == 0 {
return Err(ConfigError::from(ValidationError::ZeroChannelCapacity).into());
}
if self.arp_spoof.spoof_interval_secs == 0 {
return Err(ConfigError::from(ValidationError::ZeroSpoofInterval).into());
}
for pattern in &self.blocklist {
if pattern.is_empty() {
return Err(ConfigError::from(ValidationError::EmptyBlocklistPattern).into());
}
if pattern.starts_with("*.") && pattern.len() <= 2 {
return Err(ConfigError::from(ValidationError::InvalidWildcardPattern {
pattern: pattern.clone(),
})
.into());
}
}
self.validate_blocklist_sources()?;
Ok(())
}
fn validate_blocklist_sources(&self) -> Result<()> {
let mut seen_names = HashSet::new();
for source in &self.blocklist_sources {
if source.name.is_empty() {
return Err(ConfigError::from(ValidationError::EmptyBlocklistSourceName).into());
}
if !seen_names.insert(&source.name) {
return Err(
ConfigError::from(ValidationError::DuplicateBlocklistSourceName {
name: source.name.clone(),
})
.into(),
);
}
match &source.source {
BlocklistSourceType::File { path } => {
if path.as_os_str().is_empty() {
return Err(
ConfigError::from(ValidationError::EmptyBlocklistSourcePath {
name: source.name.clone(),
})
.into(),
);
}
if source.refresh_interval_hours.is_some() {
tracing::warn!(
name = ?source.name,
"refresh_interval_hours is ignored for file sources"
);
}
}
BlocklistSourceType::Remote { url } => {
if url.is_empty() {
return Err(ConfigError::from(ValidationError::EmptyBlocklistSourceUrl {
name: source.name.clone(),
})
.into());
}
if !url.starts_with("http://") && !url.starts_with("https://") {
return Err(ConfigError::from(
ValidationError::InvalidBlocklistSourceUrl {
name: source.name.clone(),
url: url.clone(),
},
)
.into());
}
}
}
}
Ok(())
}
}
#[cfg(test)]
mod tests {
    // Unit tests for TOML parsing, serde defaults and `Config::validate`.
    use super::*;

    #[test]
    fn should_parse_valid_config() {
        let toml = r#"
upstream_resolver = "1.1.1.1:53"
cache_ttl_seconds = 600
blocklist = ["example.com", "*.ads.com"]
"#;
        let config = Config::parse(toml).unwrap();
        assert_eq!(config.upstream_resolver.to_string(), "1.1.1.1:53");
        assert_eq!(config.cache_ttl_seconds, 600);
        assert_eq!(config.blocklist.len(), 2);
        assert!(config.interface.is_none());
    }

    #[test]
    fn should_parse_config_with_interface() {
        let toml = r#"
interface = "eth0"
upstream_resolver = "8.8.8.8:53"
"#;
        let config = Config::parse(toml).unwrap();
        assert_eq!(config.interface.as_deref(), Some("eth0"));
    }

    #[test]
    fn should_parse_ipv6_upstream_resolver() {
        let toml = r#"
upstream_resolver = "[2606:4700:4700::1111]:53"
"#;
        let config = Config::parse(toml).unwrap();
        assert!(config.upstream_resolver.is_ipv6());
        assert_eq!(config.upstream_resolver.port(), 53);
    }

    #[test]
    fn should_use_default_values_when_not_specified() {
        let toml = r#"
upstream_resolver = "1.1.1.1:53"
"#;
        let config = Config::parse(toml).unwrap();
        assert_eq!(config.cache_ttl_seconds, 300);
        assert_eq!(config.buffer_pool_size, 64);
        assert_eq!(config.channel_capacity, 1000);
        assert!(config.blocklist.is_empty());
        assert!(!config.arp_spoof.enabled);
    }

    #[test]
    fn should_parse_arp_spoof_config() {
        let toml = r#"
upstream_resolver = "1.1.1.1:53"
[arp_spoof]
enabled = true
gateway_ip = "192.168.1.1"
spoof_interval_secs = 5
restore_on_shutdown = true
forward_traffic = true
"#;
        let config = Config::parse(toml).unwrap();
        assert!(config.arp_spoof.enabled);
        assert_eq!(
            config.arp_spoof.gateway_ip,
            Some(Ipv4Addr::new(192, 168, 1, 1))
        );
        assert_eq!(config.arp_spoof.spoof_interval_secs, 5);
        assert!(config.arp_spoof.restore_on_shutdown);
        assert!(config.arp_spoof.forward_traffic);
    }

    #[test]
    fn should_use_arp_spoof_defaults_when_not_specified() {
        let toml = r#"
upstream_resolver = "1.1.1.1:53"
[arp_spoof]
enabled = true
"#;
        let config = Config::parse(toml).unwrap();
        assert!(config.arp_spoof.enabled);
        assert!(config.arp_spoof.gateway_ip.is_none());
        assert_eq!(config.arp_spoof.spoof_interval_secs, 2);
        assert!(config.arp_spoof.restore_on_shutdown);
        assert!(config.arp_spoof.forward_traffic);
    }

    #[test]
    fn should_parse_metrics_config() {
        let toml = r#"
upstream_resolver = "1.1.1.1:53"
[metrics]
enabled = true
listen = "127.0.0.1:9100"
"#;
        let config = Config::parse(toml).unwrap();
        assert!(config.metrics.enabled);
        assert_eq!(
            config.metrics.listen,
            SocketAddr::from(([127, 0, 0, 1], 9100))
        );
    }

    #[test]
    fn should_use_metrics_defaults_when_not_specified() {
        let toml = r#"
upstream_resolver = "1.1.1.1:53"
"#;
        let config = Config::parse(toml).unwrap();
        assert!(!config.metrics.enabled);
        assert_eq!(config.metrics.listen, SocketAddr::from(([0, 0, 0, 0], 9090)));
    }

    #[test]
    fn should_reject_unknown_metrics_field() {
        let toml = r#"
upstream_resolver = "1.1.1.1:53"
[metrics]
unknown_field = "value"
"#;
        assert!(Config::parse(toml).is_err());
    }

    #[test]
    fn should_reject_invalid_resolver_address() {
        let toml = r#"
upstream_resolver = "not-an-address"
"#;
        assert!(Config::parse(toml).is_err());
    }

    #[test]
    fn should_reject_zero_cache_ttl() {
        let toml = r#"
upstream_resolver = "1.1.1.1:53"
cache_ttl_seconds = 0
"#;
        assert!(Config::parse(toml).is_err());
    }

    #[test]
    fn should_reject_empty_blocklist_pattern() {
        let toml = r#"
upstream_resolver = "1.1.1.1:53"
blocklist = ["example.com", ""]
"#;
        assert!(Config::parse(toml).is_err());
    }

    #[test]
    fn should_reject_unknown_field() {
        let toml = r#"
upstream_resolver = "1.1.1.1:53"
unknown_field = "value"
"#;
        assert!(Config::parse(toml).is_err());
    }

    #[test]
    fn should_reject_zero_spoof_interval() {
        let toml = r#"
upstream_resolver = "1.1.1.1:53"
[arp_spoof]
enabled = true
spoof_interval_secs = 0
"#;
        assert!(Config::parse(toml).is_err());
    }

    #[test]
    fn should_parse_blocklist_source_file() {
        let toml = r#"
upstream_resolver = "1.1.1.1:53"
[[blocklist_sources]]
name = "local-custom"
enabled = true
source = { type = "file", path = "/etc/bluebox/custom-blocklist.txt" }
format = "domains"
"#;
        let config = Config::parse(toml).unwrap();
        assert_eq!(config.blocklist_sources.len(), 1);
        let source = &config.blocklist_sources[0];
        assert_eq!(source.name, "local-custom");
        assert!(source.enabled);
        assert_eq!(source.format, BlocklistFormat::Domains);
        assert!(source.refresh_interval_hours.is_none());
        match &source.source {
            BlocklistSourceType::File { path } => {
                assert_eq!(path.to_str().unwrap(), "/etc/bluebox/custom-blocklist.txt");
            }
            BlocklistSourceType::Remote { .. } => panic!("expected file source"),
        }
    }

    #[test]
    fn should_parse_blocklist_source_remote() {
        let toml = r#"
upstream_resolver = "1.1.1.1:53"
[[blocklist_sources]]
name = "steven-black-hosts"
enabled = true
source = { type = "remote", url = "https://raw.githubusercontent.com/StevenBlack/hosts/master/hosts" }
format = "hosts"
refresh_interval_hours = 24
"#;
        let config = Config::parse(toml).unwrap();
        assert_eq!(config.blocklist_sources.len(), 1);
        let source = &config.blocklist_sources[0];
        assert_eq!(source.name, "steven-black-hosts");
        assert!(source.enabled);
        assert_eq!(source.format, BlocklistFormat::Hosts);
        assert_eq!(source.refresh_interval_hours, Some(24));
        match &source.source {
            BlocklistSourceType::Remote { url } => {
                assert_eq!(
                    url,
                    "https://raw.githubusercontent.com/StevenBlack/hosts/master/hosts"
                );
            }
            BlocklistSourceType::File { .. } => panic!("expected remote source"),
        }
    }

    #[test]
    fn should_use_blocklist_source_defaults() {
        let toml = r#"
upstream_resolver = "1.1.1.1:53"
[[blocklist_sources]]
name = "test"
source = { type = "file", path = "/path/to/file.txt" }
"#;
        let config = Config::parse(toml).unwrap();
        let source = &config.blocklist_sources[0];
        assert!(source.enabled);
        assert_eq!(source.format, BlocklistFormat::Domains);
    }

    #[test]
    fn should_parse_disabled_blocklist_source() {
        let toml = r#"
upstream_resolver = "1.1.1.1:53"
[[blocklist_sources]]
name = "disabled-source"
enabled = false
source = { type = "remote", url = "https://example.com/blocklist.txt" }
"#;
        let config = Config::parse(toml).unwrap();
        assert!(!config.blocklist_sources[0].enabled);
    }

    #[test]
    fn should_parse_blocklist_source_adblock_format() {
        let toml = r#"
upstream_resolver = "1.1.1.1:53"
[[blocklist_sources]]
name = "adguard"
source = { type = "remote", url = "https://example.com/filter.txt" }
format = "adblock"
"#;
        let config = Config::parse(toml).unwrap();
        assert_eq!(config.blocklist_sources[0].format, BlocklistFormat::Adblock);
    }

    #[test]
    fn should_parse_multiple_blocklist_sources() {
        let toml = r#"
upstream_resolver = "1.1.1.1:53"
[[blocklist_sources]]
name = "source-1"
source = { type = "file", path = "/path/1.txt" }
[[blocklist_sources]]
name = "source-2"
source = { type = "remote", url = "https://example.com/list.txt" }
format = "hosts"
"#;
        let config = Config::parse(toml).unwrap();
        assert_eq!(config.blocklist_sources.len(), 2);
        assert_eq!(config.blocklist_sources[0].name, "source-1");
        assert_eq!(config.blocklist_sources[1].name, "source-2");
    }

    #[test]
    fn should_parse_blocklist_sources_with_legacy_blocklist() {
        let toml = r#"
upstream_resolver = "1.1.1.1:53"
blocklist = ["custom-domain.com"]
[[blocklist_sources]]
name = "remote-list"
source = { type = "remote", url = "https://example.com/list.txt" }
"#;
        let config = Config::parse(toml).unwrap();
        assert_eq!(config.blocklist.len(), 1);
        assert_eq!(config.blocklist[0], "custom-domain.com");
        assert_eq!(config.blocklist_sources.len(), 1);
    }

    #[test]
    fn should_reject_duplicate_blocklist_source_name() {
        let toml = r#"
upstream_resolver = "1.1.1.1:53"
[[blocklist_sources]]
name = "same-name"
source = { type = "file", path = "/path/1.txt" }
[[blocklist_sources]]
name = "same-name"
source = { type = "file", path = "/path/2.txt" }
"#;
        assert!(Config::parse(toml).is_err());
    }

    #[test]
    fn should_reject_empty_blocklist_source_name() {
        let toml = r#"
upstream_resolver = "1.1.1.1:53"
[[blocklist_sources]]
name = ""
source = { type = "file", path = "/path/file.txt" }
"#;
        assert!(Config::parse(toml).is_err());
    }

    #[test]
    fn should_reject_empty_blocklist_source_path() {
        let toml = r#"
upstream_resolver = "1.1.1.1:53"
[[blocklist_sources]]
name = "test"
source = { type = "file", path = "" }
"#;
        assert!(Config::parse(toml).is_err());
    }

    #[test]
    fn should_reject_empty_blocklist_source_url() {
        let toml = r#"
upstream_resolver = "1.1.1.1:53"
[[blocklist_sources]]
name = "test"
source = { type = "remote", url = "" }
"#;
        assert!(Config::parse(toml).is_err());
    }

    #[test]
    fn should_reject_invalid_blocklist_source_url() {
        let toml = r#"
upstream_resolver = "1.1.1.1:53"
[[blocklist_sources]]
name = "test"
source = { type = "remote", url = "ftp://example.com/list.txt" }
"#;
        assert!(Config::parse(toml).is_err());
    }

    #[test]
    fn should_reject_unknown_blocklist_source_field() {
        let toml = r#"
upstream_resolver = "1.1.1.1:53"
[[blocklist_sources]]
name = "test"
source = { type = "file", path = "/path/file.txt" }
unknown_field = "value"
"#;
        assert!(Config::parse(toml).is_err());
    }

    #[test]
    fn should_reject_zero_buffer_pool_size() {
        let toml = r#"
upstream_resolver = "1.1.1.1:53"
buffer_pool_size = 0
"#;
        assert!(Config::parse(toml).is_err());
    }

    #[test]
    fn should_reject_zero_channel_capacity() {
        let toml = r#"
upstream_resolver = "1.1.1.1:53"
channel_capacity = 0
"#;
        assert!(Config::parse(toml).is_err());
    }

    #[test]
    fn should_reject_invalid_wildcard_pattern() {
        let toml = r#"
upstream_resolver = "1.1.1.1:53"
blocklist = ["*."]
"#;
        assert!(Config::parse(toml).is_err());
    }

    #[test]
    fn should_allow_file_source_with_refresh_interval() {
        // A refresh interval on a file source is accepted (validation only
        // logs a warning).
        let toml = r#"
upstream_resolver = "1.1.1.1:53"
[[blocklist_sources]]
name = "local-custom"
source = { type = "file", path = "/etc/bluebox/custom-blocklist.txt" }
refresh_interval_hours = 24
"#;
        let config = Config::parse(toml).unwrap();
        assert_eq!(config.blocklist_sources[0].refresh_interval_hours, Some(24));
    }

    #[test]
    fn should_return_error_when_loading_nonexistent_file() {
        let result = Config::load("/nonexistent/path/to/config.toml");
        assert!(result.is_err());
    }

    #[test]
    fn should_parse_blocklist_cache_dir() {
        let toml = r#"
upstream_resolver = "1.1.1.1:53"
blocklist_cache_dir = "/var/cache/bluebox/blocklists"
"#;
        let config = Config::parse(toml).unwrap();
        assert_eq!(
            config.blocklist_cache_dir,
            Some(PathBuf::from("/var/cache/bluebox/blocklists"))
        );
        assert_eq!(
            config.blocklist_cache_dir(),
            PathBuf::from("/var/cache/bluebox/blocklists")
        );
    }

    #[test]
    fn should_use_default_blocklist_cache_dir_when_not_specified() {
        let toml = r#"
upstream_resolver = "1.1.1.1:53"
"#;
        let config = Config::parse(toml).unwrap();
        assert!(config.blocklist_cache_dir.is_none());
        let cache_dir = config.blocklist_cache_dir();
        // `Path::ends_with` compares whole components, so this also covers a
        // `.../bluebox/blocklists` layout.
        assert!(cache_dir.ends_with("blocklists"));
    }
}