use std::io::BufReader;
use std::path::{Path, PathBuf};
use std::time::Duration;
use reqwest::Client;
use tokio::fs::{self, File};
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use super::{ParseError, parser_for_format};
use crate::config::BlocklistFormat;
/// Timeout, in seconds, configured on the HTTP client built by `RemoteLoader::new`.
const DEFAULT_TIMEOUT_SECS: u64 = 30;
/// `User-Agent` header value sent with every request: `bluebox/<crate version>`,
/// where the version is taken from `CARGO_PKG_VERSION` at compile time.
const USER_AGENT: &str = concat!("bluebox/", env!("CARGO_PKG_VERSION"));
/// Errors that can occur while fetching, caching, or parsing a remote blocklist.
#[derive(Debug, thiserror::Error)]
pub enum RemoteLoadError {
/// The server answered, but with a non-success HTTP status code.
#[error("HTTP request failed for {url}: status {status}")]
HttpStatus {
/// URL that was requested.
url: String,
/// Numeric HTTP status code of the response.
status: u16,
},
/// A transport-level failure while sending the request or reading the response body.
#[error("network error fetching {url}")]
Network {
/// URL that was requested.
url: String,
/// Underlying `reqwest` error.
#[source]
source: reqwest::Error,
},
/// The request timed out.
#[error("timeout fetching {url}")]
Timeout {
/// URL that was requested.
url: String,
},
/// The blocklist content could not be parsed.
#[error("parse error")]
Parse(#[from] ParseError),
/// The background task running the parser could not be joined.
#[error("task join error")]
Join(#[from] tokio::task::JoinError),
/// The cache file at the given path does not exist.
#[error("cache not available: {0:?}")]
CacheUnavailable(PathBuf),
/// Reading or writing the cache failed for a reason other than a missing cache file.
#[error("cache I/O error for {path:?}")]
CacheIo {
/// Path the failing operation was performed on.
path: PathBuf,
/// Underlying I/O error.
#[source]
source: std::io::Error,
},
/// The HTTP client could not be constructed.
#[error("failed to create HTTP client")]
ClientBuild(#[source] reqwest::Error),
}
/// Loads blocklists from remote URLs, optionally falling back to an on-disk cache.
pub struct RemoteLoader {
/// HTTP client shared by all requests made through this loader.
client: Client,
/// Directory in which `<name>.cache` files are stored.
cache_dir: PathBuf,
}
impl RemoteLoader {
pub fn new(cache_dir: PathBuf) -> Result<Self, RemoteLoadError> {
let client = Client::builder()
.timeout(Duration::from_secs(DEFAULT_TIMEOUT_SECS))
.user_agent(USER_AGENT)
.gzip(true)
.build()
.map_err(RemoteLoadError::ClientBuild)?;
Ok(Self { client, cache_dir })
}
pub async fn load(
&self,
url: &str,
format: BlocklistFormat,
) -> Result<Vec<String>, RemoteLoadError> {
let response = self.client.get(url).send().await.map_err(|err| {
if err.is_timeout() {
RemoteLoadError::Timeout {
url: url.to_string(),
}
} else {
RemoteLoadError::Network {
url: url.to_string(),
source: err,
}
}
})?;
if !response.status().is_success() {
return Err(RemoteLoadError::HttpStatus {
url: url.to_string(),
status: response.status().as_u16(),
});
}
let content = response
.text()
.await
.map_err(|err| RemoteLoadError::Network {
url: url.to_string(),
source: err,
})?;
let domains = tokio::task::spawn_blocking(move || {
let parser = parser_for_format(format);
let mut reader = BufReader::new(content.as_bytes());
parser.parse(&mut reader)
})
.await??;
Ok(domains)
}
pub async fn load_cached(
&self,
name: &str,
url: &str,
format: BlocklistFormat,
) -> Result<Vec<String>, RemoteLoadError> {
let cache_path = self.cache_path(name);
match self.load_and_cache(url, format, &cache_path).await {
Ok(patterns) => Ok(patterns),
Err(err) => {
tracing::warn!(
url = %url,
error = ?err,
"failed to fetch remote blocklist, trying cache"
);
self.load_from_cache(&cache_path, format).await
}
}
}
async fn load_and_cache(
&self,
url: &str,
format: BlocklistFormat,
cache_path: &Path,
) -> Result<Vec<String>, RemoteLoadError> {
let response = self.client.get(url).send().await.map_err(|err| {
if err.is_timeout() {
RemoteLoadError::Timeout {
url: url.to_string(),
}
} else {
RemoteLoadError::Network {
url: url.to_string(),
source: err,
}
}
})?;
if !response.status().is_success() {
return Err(RemoteLoadError::HttpStatus {
url: url.to_string(),
status: response.status().as_u16(),
});
}
let content = response
.text()
.await
.map_err(|err| RemoteLoadError::Network {
url: url.to_string(),
source: err,
})?;
if let Err(err) = self.save_cache(cache_path, &content).await {
tracing::warn!(
path = ?cache_path,
error = ?err,
"failed to save blocklist to cache"
);
}
let domains = tokio::task::spawn_blocking(move || {
let parser = parser_for_format(format);
let mut reader = BufReader::new(content.as_bytes());
parser.parse(&mut reader)
})
.await??;
Ok(domains)
}
async fn save_cache(&self, cache_path: &Path, content: &str) -> Result<(), RemoteLoadError> {
if let Some(parent) = cache_path.parent() {
fs::create_dir_all(parent)
.await
.map_err(|err| RemoteLoadError::CacheIo {
path: parent.to_path_buf(),
source: err,
})?;
}
let mut file = File::create(cache_path)
.await
.map_err(|err| RemoteLoadError::CacheIo {
path: cache_path.to_path_buf(),
source: err,
})?;
file.write_all(content.as_bytes())
.await
.map_err(|err| RemoteLoadError::CacheIo {
path: cache_path.to_path_buf(),
source: err,
})?;
file.flush().await.map_err(|err| RemoteLoadError::CacheIo {
path: cache_path.to_path_buf(),
source: err,
})?;
tracing::debug!(path = ?cache_path, "saved blocklist to cache");
Ok(())
}
async fn load_from_cache(
&self,
cache_path: &Path,
format: BlocklistFormat,
) -> Result<Vec<String>, RemoteLoadError> {
let mut file = File::open(cache_path).await.map_err(|err| {
if err.kind() == std::io::ErrorKind::NotFound {
RemoteLoadError::CacheUnavailable(cache_path.to_path_buf())
} else {
RemoteLoadError::CacheIo {
path: cache_path.to_path_buf(),
source: err,
}
}
})?;
let mut content = String::new();
file.read_to_string(&mut content)
.await
.map_err(|err| RemoteLoadError::CacheIo {
path: cache_path.to_path_buf(),
source: err,
})?;
tracing::info!(path = ?cache_path, "loaded blocklist from cache");
let domains = tokio::task::spawn_blocking(move || {
let parser = parser_for_format(format);
let mut reader = BufReader::new(content.as_bytes());
parser.parse(&mut reader)
})
.await??;
Ok(domains)
}
fn cache_path(&self, name: &str) -> PathBuf {
self.cache_dir.join(format!("{name}.cache"))
}
}
/// Returns the default directory for cached blocklists.
///
/// This is `bluebox/blocklists` inside the directory reported by
/// `dirs::cache_dir()`, or the relative path `./cache/blocklists` when no
/// cache directory is reported.
#[must_use]
pub fn default_cache_dir() -> PathBuf {
    match dirs::cache_dir() {
        Some(base) => base.join("bluebox").join("blocklists"),
        None => PathBuf::from("./cache/blocklists"),
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;
    use wiremock::matchers::{method, path};
    use wiremock::{Mock, MockServer, ResponseTemplate};

    /// Builds a loader backed by a fresh temporary cache directory.
    ///
    /// The `TempDir` is handed back so the directory stays alive for the
    /// duration of the test.
    fn create_loader() -> (RemoteLoader, TempDir) {
        let cache_root = TempDir::new().unwrap();
        let loader = RemoteLoader::new(cache_root.path().to_path_buf()).unwrap();
        (loader, cache_root)
    }

    /// Starts a mock definition matching `GET <route>`.
    fn on_get(route: &str) -> wiremock::MockBuilder {
        Mock::given(method("GET")).and(path(route))
    }

    /// Response with the given status code and a plain string body.
    fn text_response(status: u16, body: &str) -> ResponseTemplate {
        ResponseTemplate::new(status).set_body_string(body)
    }

    /// Absolute URL of `route` on the mock server.
    fn url_for(server: &MockServer, route: &str) -> String {
        format!("{}{route}", server.uri())
    }

    #[tokio::test]
    async fn should_load_domains_format_from_url() {
        let server = MockServer::start().await;
        on_get("/blocklist.txt")
            .respond_with(text_response(200, "# Comment\nexample.com\n*.ads.com"))
            .mount(&server)
            .await;
        let (loader, _cache) = create_loader();

        let loaded = loader
            .load(&url_for(&server, "/blocklist.txt"), BlocklistFormat::Domains)
            .await
            .unwrap();

        assert_eq!(loaded, vec!["example.com", "*.ads.com"]);
    }

    #[tokio::test]
    async fn should_load_hosts_format_from_url() {
        let server = MockServer::start().await;
        on_get("/hosts")
            .respond_with(text_response(
                200,
                "# Hosts file\n0.0.0.0 ads.example.com\n127.0.0.1 tracking.example.com",
            ))
            .mount(&server)
            .await;
        let (loader, _cache) = create_loader();

        let loaded = loader
            .load(&url_for(&server, "/hosts"), BlocklistFormat::Hosts)
            .await
            .unwrap();

        assert_eq!(loaded, vec!["ads.example.com", "tracking.example.com"]);
    }

    #[tokio::test]
    async fn should_load_adblock_format_from_url() {
        let server = MockServer::start().await;
        on_get("/filter.txt")
            .respond_with(text_response(
                200,
                "! AdBlock comment\n||ads.example.com^\n||tracking.example.com^$third-party",
            ))
            .mount(&server)
            .await;
        let (loader, _cache) = create_loader();

        let loaded = loader
            .load(&url_for(&server, "/filter.txt"), BlocklistFormat::Adblock)
            .await
            .unwrap();

        assert_eq!(loaded, vec!["ads.example.com", "tracking.example.com"]);
    }

    #[tokio::test]
    async fn should_return_http_status_error_when_not_found() {
        let server = MockServer::start().await;
        on_get("/notfound.txt")
            .respond_with(ResponseTemplate::new(404))
            .mount(&server)
            .await;
        let (loader, _cache) = create_loader();

        let outcome = loader
            .load(&url_for(&server, "/notfound.txt"), BlocklistFormat::Domains)
            .await;

        assert!(matches!(
            outcome,
            Err(RemoteLoadError::HttpStatus { status: 404, .. })
        ));
    }

    #[tokio::test]
    async fn should_return_http_status_error_when_server_error() {
        let server = MockServer::start().await;
        on_get("/error.txt")
            .respond_with(ResponseTemplate::new(500))
            .mount(&server)
            .await;
        let (loader, _cache) = create_loader();

        let outcome = loader
            .load(&url_for(&server, "/error.txt"), BlocklistFormat::Domains)
            .await;

        assert!(matches!(
            outcome,
            Err(RemoteLoadError::HttpStatus { status: 500, .. })
        ));
    }

    #[tokio::test]
    async fn should_return_network_error_when_connection_refused() {
        let (loader, _cache) = create_loader();

        let outcome = loader
            .load("http://127.0.0.1:1/blocklist.txt", BlocklistFormat::Domains)
            .await;

        assert!(matches!(outcome, Err(RemoteLoadError::Network { .. })));
    }

    #[tokio::test]
    async fn should_cache_response_on_successful_load() {
        let server = MockServer::start().await;
        on_get("/blocklist.txt")
            .respond_with(text_response(200, "example.com"))
            .expect(1)
            .mount(&server)
            .await;
        let (loader, cache_root) = create_loader();

        let loaded = loader
            .load_cached(
                "test",
                &url_for(&server, "/blocklist.txt"),
                BlocklistFormat::Domains,
            )
            .await
            .unwrap();
        assert_eq!(loaded, vec!["example.com"]);

        // The raw body must have been written to `<cache_dir>/test.cache`.
        let cache_file = cache_root.path().join("test.cache");
        assert!(cache_file.exists());
        assert_eq!(std::fs::read_to_string(&cache_file).unwrap(), "example.com");
    }

    #[tokio::test]
    async fn should_fallback_to_cache_when_remote_fails() {
        let server = MockServer::start().await;
        let target = url_for(&server, "/blocklist.txt");
        let (loader, _cache) = create_loader();

        // First round: the server answers normally and the cache gets populated.
        on_get("/blocklist.txt")
            .respond_with(text_response(200, "example.com"))
            .expect(1)
            .mount(&server)
            .await;
        let first = loader
            .load_cached("test", &target, BlocklistFormat::Domains)
            .await
            .unwrap();
        assert_eq!(first, vec!["example.com"]);

        // Second round: the server fails, so the cached copy must be served.
        server.reset().await;
        on_get("/blocklist.txt")
            .respond_with(ResponseTemplate::new(503))
            .expect(1)
            .mount(&server)
            .await;
        let second = loader
            .load_cached("test", &target, BlocklistFormat::Domains)
            .await
            .unwrap();
        assert_eq!(second, vec!["example.com"]);
    }

    #[tokio::test]
    async fn should_return_cache_unavailable_when_no_cache_exists() {
        let server = MockServer::start().await;
        on_get("/blocklist.txt")
            .respond_with(ResponseTemplate::new(503))
            .mount(&server)
            .await;
        let (loader, _cache) = create_loader();

        let outcome = loader
            .load_cached(
                "nonexistent",
                &url_for(&server, "/blocklist.txt"),
                BlocklistFormat::Domains,
            )
            .await;

        assert!(matches!(outcome, Err(RemoteLoadError::CacheUnavailable(_))));
    }

    #[tokio::test]
    async fn should_handle_empty_response() {
        let server = MockServer::start().await;
        on_get("/empty.txt")
            .respond_with(text_response(200, ""))
            .mount(&server)
            .await;
        let (loader, _cache) = create_loader();

        let loaded = loader
            .load(&url_for(&server, "/empty.txt"), BlocklistFormat::Domains)
            .await
            .unwrap();

        assert!(loaded.is_empty());
    }

    #[tokio::test]
    async fn should_handle_comments_only_response() {
        let server = MockServer::start().await;
        on_get("/comments.txt")
            .respond_with(text_response(200, "# Comment 1\n# Comment 2\n"))
            .mount(&server)
            .await;
        let (loader, _cache) = create_loader();

        let loaded = loader
            .load(&url_for(&server, "/comments.txt"), BlocklistFormat::Domains)
            .await
            .unwrap();

        assert!(loaded.is_empty());
    }

    #[tokio::test]
    async fn should_include_user_agent_header() {
        let server = MockServer::start().await;
        // The mock only matches when the expected User-Agent header is present.
        on_get("/blocklist.txt")
            .and(wiremock::matchers::header("User-Agent", USER_AGENT))
            .respond_with(text_response(200, "example.com"))
            .expect(1)
            .mount(&server)
            .await;
        let (loader, _cache) = create_loader();

        let loaded = loader
            .load(&url_for(&server, "/blocklist.txt"), BlocklistFormat::Domains)
            .await
            .unwrap();

        assert_eq!(loaded, vec!["example.com"]);
    }

    #[tokio::test]
    async fn should_handle_redirect_responses() {
        let server = MockServer::start().await;
        on_get("/redirect")
            .respond_with(
                ResponseTemplate::new(302).append_header("Location", url_for(&server, "/actual")),
            )
            .mount(&server)
            .await;
        on_get("/actual")
            .respond_with(text_response(200, "example.com"))
            .mount(&server)
            .await;
        let (loader, _cache) = create_loader();

        let loaded = loader
            .load(&url_for(&server, "/redirect"), BlocklistFormat::Domains)
            .await
            .unwrap();

        assert_eq!(loaded, vec!["example.com"]);
    }

    #[test]
    fn should_return_default_cache_dir() {
        let dir = default_cache_dir();
        assert!(dir.ends_with("bluebox/blocklists") || dir.ends_with("blocklists"));
    }

    #[tokio::test]
    async fn should_handle_large_response() {
        let server = MockServer::start().await;
        // 10 000 newline-terminated entries: domain0 … domain9999.
        let body: String = (0..10_000)
            .map(|i| format!("domain{i}.example.com\n"))
            .collect();
        on_get("/large.txt")
            .respond_with(ResponseTemplate::new(200).set_body_string(body))
            .mount(&server)
            .await;
        let (loader, _cache) = create_loader();

        let loaded = loader
            .load(&url_for(&server, "/large.txt"), BlocklistFormat::Domains)
            .await
            .unwrap();

        assert_eq!(loaded.len(), 10_000);
        assert_eq!(loaded[0], "domain0.example.com");
        assert_eq!(loaded[9999], "domain9999.example.com");
    }
}