use std::collections::HashMap;
use std::path::PathBuf;
use std::sync::Arc;
use parking_lot::RwLock;
use super::loader::{FileLoader, LoadError};
use super::remote::RemoteLoader;
use crate::config::{BlocklistSourceConfig, BlocklistSourceType, Config};
use crate::dns::Blocker;
/// Errors returned by [`BlocklistManager`] operations.
#[derive(Debug, thiserror::Error)]
pub enum ManagerError {
/// A source could not be resolved. The payload is either the requested source
/// name (no configured source matched it) or a message naming the source when
/// no remote loader is available for it.
#[error("unknown blocklist source: {0:?}")]
UnknownSource(String),
/// Loading a file-backed source failed; converted from [`LoadError`] via `#[from]`.
#[error("failed to load file blocklist")]
FileLoad(#[from] LoadError),
/// A remote-loader operation failed (creating the loader or loading a remote
/// source); converted from `RemoteLoadError` via `#[from]`.
#[error("failed to load remote blocklist")]
RemoteLoad(#[from] super::remote::RemoteLoadError),
}
/// Per-source statistics reported by `BlocklistManager::stats`.
///
/// Derives the common comparison and default traits so callers (and tests) can
/// compare snapshots with `==` / `assert_eq!` and build a zeroed value.
#[derive(Debug, Clone, PartialEq, Eq, Default)]
pub struct SourceStats {
    /// Number of patterns currently held for the source, counted before any
    /// merging with other sources.
    pub pattern_count: usize,
}
/// Holds the configured blocklist sources, the patterns loaded from each of
/// them, and the shared [`Blocker`] built from the union of those patterns.
pub struct BlocklistManager {
/// Shared handle to the live blocker; its contents are replaced on every rebuild.
blocker: Arc<RwLock<Blocker>>,
/// Patterns currently loaded, keyed by source name (inline patterns are stored
/// under `INLINE_SOURCE_NAME`).
source_patterns: RwLock<HashMap<String, Vec<String>>>,
/// Source definitions copied from the configuration at construction time.
sources: Vec<BlocklistSourceConfig>,
/// Patterns listed directly in the configuration (`Config::blocklist`).
inline_patterns: Vec<String>,
/// Loader for remote sources; `None` when no remote source is configured.
remote_loader: Option<RemoteLoader>,
}
impl BlocklistManager {
/// Key under which the inline configuration patterns are stored in
/// `source_patterns` (and therefore reported by `stats`).
const INLINE_SOURCE_NAME: &'static str = "__inline__";
pub fn new(config: &Config) -> Result<Self, ManagerError> {
let cache_dir = config.blocklist_cache_dir();
let has_remote_sources = config
.blocklist_sources
.iter()
.any(|s| matches!(s.source, BlocklistSourceType::Remote { .. }));
let remote_loader = if has_remote_sources {
Some(RemoteLoader::new(cache_dir).map_err(ManagerError::RemoteLoad)?)
} else {
None
};
Ok(Self {
blocker: Arc::new(RwLock::new(Blocker::default())),
source_patterns: RwLock::new(HashMap::new()),
sources: config.blocklist_sources.clone(),
inline_patterns: config.blocklist.clone(),
remote_loader,
})
}
/// Creates a manager from `config`, handing `cache_dir` to the remote loader.
///
/// The remote loader is only constructed when at least one configured source
/// is remote; otherwise `cache_dir` is unused. The blocker starts out as
/// `Blocker::default()` and no patterns are loaded until `initialize` runs.
///
/// # Errors
///
/// Returns [`ManagerError::RemoteLoad`] if the remote loader cannot be created.
pub fn with_cache_dir(config: &Config, cache_dir: PathBuf) -> Result<Self, ManagerError> {
    let needs_remote = config
        .blocklist_sources
        .iter()
        .any(|entry| matches!(entry.source, BlocklistSourceType::Remote { .. }));

    // `Option<Result<_>>` -> `Result<Option<_>>`, so a failed loader aborts
    // construction while "no remote sources" simply yields `None`.
    let remote_loader = needs_remote
        .then(|| RemoteLoader::new(cache_dir))
        .transpose()
        .map_err(ManagerError::RemoteLoad)?;

    let sources = config.blocklist_sources.clone();
    let inline_patterns = config.blocklist.clone();

    Ok(Self {
        blocker: Arc::new(RwLock::new(Blocker::default())),
        source_patterns: RwLock::new(HashMap::new()),
        sources,
        inline_patterns,
        remote_loader,
    })
}
/// Returns another handle to the shared blocker.
///
/// Every call hands out a clone of the same `Arc`, so rebuilds performed by
/// the manager are visible through previously returned handles.
#[must_use]
pub fn blocker(&self) -> Arc<RwLock<Blocker>> {
    let shared = &self.blocker;
    Arc::clone(shared)
}
/// Loads the inline patterns and every enabled source, then rebuilds the blocker.
///
/// A source that fails to load is logged and skipped; it does not abort
/// initialization. Disabled sources are skipped without being loaded.
///
/// # Errors
///
/// The visible implementation always returns `Ok(())`; the `Result` return
/// type is kept for callers.
pub async fn initialize(&self) -> Result<(), ManagerError> {
    let inline = self.inline_patterns.clone();
    // The guard is a temporary of this statement, so it is released before any `.await`.
    self.source_patterns
        .write()
        .insert(Self::INLINE_SOURCE_NAME.to_string(), inline);
    tracing::info!(
        count = self.inline_patterns.len(),
        "loaded inline blocklist patterns"
    );

    for source in &self.sources {
        if !source.enabled {
            tracing::debug!(name = ?source.name, "skipping disabled blocklist source");
            continue;
        }

        let loaded = match self.load_source(source).await {
            Ok(loaded) => loaded,
            Err(err) => {
                tracing::error!(
                    name = ?source.name,
                    error = ?err,
                    "failed to load blocklist source"
                );
                continue;
            }
        };

        tracing::info!(
            name = ?source.name,
            count = loaded.len(),
            "loaded blocklist source"
        );
        self.source_patterns
            .write()
            .insert(source.name.clone(), loaded);
    }

    self.rebuild_blocker();
    Ok(())
}
/// Rebuilds the shared blocker from the patterns of every loaded source.
///
/// Logs the number of patterns passed in (`raw_patterns`) and the length the
/// new blocker reports (`unique_patterns`), then swaps the new blocker in
/// under the write lock.
fn rebuild_blocker(&self) {
    // Copy the patterns out inside a scope so the read lock is dropped before
    // the blocker is constructed.
    let merged: Vec<String> = {
        let by_source = self.source_patterns.read();
        by_source
            .values()
            .flat_map(|list| list.iter().cloned())
            .collect()
    };

    let raw_patterns = merged.len();
    let rebuilt = Blocker::new(merged);
    let unique_patterns = rebuilt.len();
    tracing::info!(raw_patterns, unique_patterns, "rebuilt blocker");

    let mut live = self.blocker.write();
    *live = rebuilt;
}
/// Loads the patterns for a single source definition.
///
/// File sources go through `FileLoader::load`; remote sources go through the
/// remote loader's `load_cached`.
///
/// # Errors
///
/// - [`ManagerError::FileLoad`] when the file loader fails.
/// - [`ManagerError::RemoteLoad`] when the remote loader fails.
/// - [`ManagerError::UnknownSource`] when a remote source is requested but the
///   manager holds no remote loader.
async fn load_source(
    &self,
    source: &BlocklistSourceConfig,
) -> Result<Vec<String>, ManagerError> {
    match &source.source {
        BlocklistSourceType::File { path } => {
            tracing::debug!(name = ?source.name, path = ?path, "loading file blocklist");
            FileLoader::load(path, source.format)
                .await
                .map_err(ManagerError::from)
        }
        BlocklistSourceType::Remote { url } => {
            let loader = match self.remote_loader.as_ref() {
                Some(loader) => loader,
                None => {
                    return Err(ManagerError::UnknownSource(format!(
                        "remote loader not available for source {0:?}",
                        source.name
                    )));
                }
            };
            tracing::debug!(name = ?source.name, url = %url, "loading remote blocklist");
            loader
                .load_cached(&source.name, url, source.format)
                .await
                .map_err(ManagerError::from)
        }
    }
}
/// Reloads the named source and rebuilds the blocker with the fresh patterns.
///
/// NOTE(review): the source's `enabled` flag is not consulted here, so
/// refreshing a source that is not currently loaded will load it.
///
/// # Errors
///
/// Returns [`ManagerError::UnknownSource`] when no configured source has this
/// name, or the loader's error when reloading fails. On failure the previously
/// stored patterns are left untouched.
pub async fn refresh_source(&self, name: &str) -> Result<(), ManagerError> {
    let source = match self.sources.iter().find(|candidate| candidate.name == name) {
        Some(found) => found,
        None => return Err(ManagerError::UnknownSource(name.to_string())),
    };

    tracing::info!(name = ?name, "refreshing blocklist source");
    let fresh = self.load_source(source).await?;
    let count = fresh.len();

    {
        let mut by_source = self.source_patterns.write();
        by_source.insert(name.to_owned(), fresh);
    }
    self.rebuild_blocker();

    tracing::info!(name = ?name, count, "refreshed blocklist source");
    Ok(())
}
/// Enables or disables a configured source at runtime and rebuilds the blocker.
///
/// Enabling loads the source and stores its patterns; disabling drops the
/// stored patterns for that name. The stored source configuration itself is
/// not modified.
///
/// # Errors
///
/// Returns [`ManagerError::UnknownSource`] when no configured source has this
/// name, or the loader's error when enabling fails. In the failure case the
/// blocker is not rebuilt.
pub async fn set_source_enabled(&self, name: &str, enabled: bool) -> Result<(), ManagerError> {
    let source = match self.sources.iter().find(|candidate| candidate.name == name) {
        Some(found) => found,
        None => return Err(ManagerError::UnknownSource(name.to_string())),
    };

    if !enabled {
        tracing::info!(name = ?name, "disabling blocklist source");
        self.source_patterns.write().remove(name);
    } else {
        tracing::info!(name = ?name, "enabling blocklist source");
        let loaded = self.load_source(source).await?;
        let count = loaded.len();
        self.source_patterns
            .write()
            .insert(name.to_owned(), loaded);
        tracing::info!(name = ?name, count, "enabled blocklist source");
    }

    self.rebuild_blocker();
    Ok(())
}
/// Returns a snapshot of the pattern count of every loaded source, keyed by
/// source name (the inline patterns appear under `INLINE_SOURCE_NAME`).
#[must_use]
pub fn stats(&self) -> HashMap<String, SourceStats> {
    let by_source = self.source_patterns.read();
    let mut summary = HashMap::with_capacity(by_source.len());
    for (name, patterns) in by_source.iter() {
        let entry = SourceStats {
            pattern_count: patterns.len(),
        };
        summary.insert(name.clone(), entry);
    }
    summary
}
/// Returns the length reported by the current shared blocker.
#[must_use]
pub fn total_patterns(&self) -> usize {
    let live = self.blocker.read();
    live.len()
}
/// Reports whether patterns are currently stored for the source `name`.
#[must_use]
pub fn is_source_loaded(&self, name: &str) -> bool {
    self.source_patterns.read().get(name).is_some()
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::config::BlocklistFormat;
use std::fs;
use std::io::Write;
use tempfile::{NamedTempFile, TempDir};
/// Builds a `Config` with a single inline pattern and the given sources.
fn create_test_config(sources: Vec<BlocklistSourceConfig>) -> Config {
    let upstream_resolver = "1.1.1.1:53".parse().unwrap();
    let blocklist = vec![String::from("inline.example.com")];
    Config {
        interface: None,
        upstream_resolver,
        cache_ttl_seconds: 300,
        blocklist,
        blocklist_sources: sources,
        blocklist_cache_dir: None,
        buffer_pool_size: 64,
        channel_capacity: 1000,
        arp_spoof: crate::config::ArpSpoofSettings::default(),
        metrics: crate::config::MetricsConfig::default(),
    }
}
/// Builds an enabled, domain-format file source named `name` that reads `path`.
fn create_file_source(name: &str, path: &std::path::Path) -> BlocklistSourceConfig {
    let source = BlocklistSourceType::File {
        path: path.to_path_buf(),
    };
    BlocklistSourceConfig {
        name: name.to_owned(),
        enabled: true,
        source,
        format: BlocklistFormat::Domains,
        refresh_interval_hours: None,
    }
}
// With no sources configured, only the single inline pattern is loaded.
#[tokio::test]
async fn should_load_inline_patterns_on_initialize() {
    let manager = BlocklistManager::new(&create_test_config(Vec::new())).unwrap();
    manager.initialize().await.unwrap();

    let inline_count = manager
        .stats()
        .get(BlocklistManager::INLINE_SOURCE_NAME)
        .map(|entry| entry.pattern_count);
    assert!(inline_count.is_some());
    assert_eq!(inline_count, Some(1));
    assert_eq!(manager.total_patterns(), 1);
}
// A file source contributes its two patterns on top of the inline one.
#[tokio::test]
async fn should_load_file_source_on_initialize() {
    let mut list_file = NamedTempFile::new().unwrap();
    list_file
        .write_all(b"file1.example.com\nfile2.example.com\n")
        .unwrap();
    list_file.flush().unwrap();

    let config = create_test_config(vec![create_file_source("test-file", list_file.path())]);
    let manager = BlocklistManager::new(&config).unwrap();
    manager.initialize().await.unwrap();

    let per_source = manager.stats();
    let file_entry = per_source.get("test-file");
    assert!(file_entry.is_some());
    assert_eq!(file_entry.map(|entry| entry.pattern_count), Some(2));
    assert_eq!(manager.total_patterns(), 3);
}
#[tokio::test]
async fn should_skip_disabled_sources() {
let mut file = NamedTempFile::new().unwrap();
writeln!(file, "disabled.example.com").unwrap();
file.flush().unwrap();
let mut source = create_file_source("disabled-source", file.path());
source.enabled = false;
let config = create_test_config(vec![source]);
let manager = BlocklistManager::new(&config).unwrap();
manager.initialize().await.unwrap();
let stats = manager.stats();
assert!(!stats.contains_key("disabled-source"));
assert_eq!(manager.total_patterns(), 1);
}
// Refreshing picks up the new contents of the backing file.
#[tokio::test]
async fn should_refresh_source() {
    let workspace = TempDir::new().unwrap();
    let list_path = workspace.path().join("blocklist.txt");
    fs::write(&list_path, "original.example.com\n").unwrap();

    let config = create_test_config(vec![create_file_source("refreshable", &list_path)]);
    let manager = BlocklistManager::new(&config).unwrap();
    manager.initialize().await.unwrap();
    let before = manager.stats()["refreshable"].pattern_count;
    assert_eq!(before, 1);

    let updated = "updated1.example.com\nupdated2.example.com\nupdated3.example.com\n";
    fs::write(&list_path, updated).unwrap();
    manager.refresh_source("refreshable").await.unwrap();

    let after = manager.stats()["refreshable"].pattern_count;
    assert_eq!(after, 3);
}
// Refreshing a name that is not configured yields `UnknownSource`.
#[tokio::test]
async fn should_return_error_for_unknown_source_on_refresh() {
    let manager = BlocklistManager::new(&create_test_config(Vec::new())).unwrap();
    manager.initialize().await.unwrap();

    match manager.refresh_source("nonexistent").await {
        Err(ManagerError::UnknownSource(_)) => {}
        other => panic!("expected UnknownSource, got {:?}", other),
    }
}
// Disabling a loaded source removes its patterns from the blocker.
#[tokio::test]
async fn should_disable_source_at_runtime() {
    let mut list_file = NamedTempFile::new().unwrap();
    list_file.write_all(b"removable.example.com\n").unwrap();
    list_file.flush().unwrap();

    let config = create_test_config(vec![create_file_source("removable", list_file.path())]);
    let manager = BlocklistManager::new(&config).unwrap();
    manager.initialize().await.unwrap();
    assert!(manager.is_source_loaded("removable"));
    assert_eq!(manager.total_patterns(), 2);

    let disabled = manager.set_source_enabled("removable", false).await;
    disabled.unwrap();

    assert!(!manager.is_source_loaded("removable"));
    assert_eq!(manager.total_patterns(), 1);
}
#[tokio::test]
async fn should_enable_source_at_runtime() {
let mut file = NamedTempFile::new().unwrap();
writeln!(file, "addable.example.com").unwrap();
file.flush().unwrap();
let mut source = create_file_source("addable", file.path());
source.enabled = false;
let config = create_test_config(vec![source]);
let manager = BlocklistManager::new(&config).unwrap();
manager.initialize().await.unwrap();
assert!(!manager.is_source_loaded("addable"));
assert_eq!(manager.total_patterns(), 1);
manager.set_source_enabled("addable", true).await.unwrap();
assert!(manager.is_source_loaded("addable"));
assert_eq!(manager.total_patterns(), 2);
}
// Every `blocker()` call must hand out the same underlying allocation.
#[tokio::test]
async fn should_share_blocker_across_clones() {
    let manager = BlocklistManager::new(&create_test_config(Vec::new())).unwrap();
    manager.initialize().await.unwrap();

    let (first, second) = (manager.blocker(), manager.blocker());
    assert!(Arc::ptr_eq(&first, &second));
}
// A handle obtained before a refresh must observe the rebuilt blocker.
#[tokio::test]
async fn should_update_shared_blocker_on_refresh() {
    let workspace = TempDir::new().unwrap();
    let list_path = workspace.path().join("blocklist.txt");
    fs::write(&list_path, "domain1.example.com\n").unwrap();

    let config = create_test_config(vec![create_file_source("test", &list_path)]);
    let manager = BlocklistManager::new(&config).unwrap();
    manager.initialize().await.unwrap();

    let shared = manager.blocker();
    let count_before = shared.read().len();

    fs::write(&list_path, "domain1.example.com\ndomain2.example.com\n").unwrap();
    manager.refresh_source("test").await.unwrap();

    let count_after = shared.read().len();
    assert!(count_before < count_after);
}
// A failing source listed first must not stop later sources from loading.
#[tokio::test]
async fn should_continue_after_source_load_failure() {
    let broken = BlocklistSourceConfig {
        name: "invalid".to_owned(),
        enabled: true,
        source: BlocklistSourceType::File {
            path: "/nonexistent/path/blocklist.txt".into(),
        },
        format: BlocklistFormat::Domains,
        refresh_interval_hours: None,
    };

    let mut good_file = NamedTempFile::new().unwrap();
    good_file.write_all(b"valid.example.com\n").unwrap();
    good_file.flush().unwrap();
    let working = create_file_source("valid", good_file.path());

    let manager = BlocklistManager::new(&create_test_config(vec![broken, working])).unwrap();
    manager.initialize().await.unwrap();

    let per_source = manager.stats();
    assert!(per_source.get("valid").is_some());
    assert!(per_source.get("invalid").is_none());
}
// Per-source counts stay raw while the blocker counts the shared pattern once.
#[tokio::test]
async fn should_deduplicate_patterns_across_sources() {
    let mut first_file = NamedTempFile::new().unwrap();
    first_file
        .write_all(b"duplicate.example.com\nunique1.example.com\n")
        .unwrap();
    first_file.flush().unwrap();

    let mut second_file = NamedTempFile::new().unwrap();
    second_file
        .write_all(b"duplicate.example.com\nunique2.example.com\n")
        .unwrap();
    second_file.flush().unwrap();

    let sources = vec![
        create_file_source("source1", first_file.path()),
        create_file_source("source2", second_file.path()),
    ];
    let mut config = create_test_config(sources);
    // Drop the inline pattern so only the two files contribute.
    config.blocklist.clear();

    let manager = BlocklistManager::new(&config).unwrap();
    manager.initialize().await.unwrap();

    let per_source = manager.stats();
    for name in ["source1", "source2"] {
        assert_eq!(per_source[name].pattern_count, 2);
    }
    assert_eq!(manager.total_patterns(), 3);
}
// Construction with an explicit cache directory behaves like `new` when no
// remote sources are configured.
#[tokio::test]
async fn should_create_manager_with_custom_cache_dir() {
    let cache_root = TempDir::new().unwrap();
    let cache_dir = cache_root.path().to_path_buf();
    let config = create_test_config(Vec::new());

    let manager = BlocklistManager::with_cache_dir(&config, cache_dir).unwrap();
    manager.initialize().await.unwrap();

    assert_eq!(manager.total_patterns(), 1);
}
}