use std::{sync::Arc, time::Duration};
use tokio::sync::Notify;
use tokio_util::sync::CancellationToken;
use tracing::{info, warn};
use crate::{
blocklist::{
aggregate::Aggregator,
fetch::{FetchOutcome, Fetcher, Validators},
parse::{BlocklistParser as _, Parser},
},
resolver::state::ResolverState,
storage::blocklists::{BlocklistRepository, RefreshMetadata, SqliteBlocklistRepo},
time::Clock,
};
/// Lower bound for the periodic refresh interval: a configured interval
/// shorter than this is raised to it before sleeping.
const MIN_INTERVAL: Duration = Duration::from_millis(60_000);
/// Counters describing the outcome of one refresh cycle.
#[derive(Clone, Debug, Default, Eq, PartialEq)]
pub struct RefreshSummary {
    /// Sources counted as freshly downloaded (HTTP 200).
    pub fetched: usize,
    /// Sources counted as not modified (HTTP 304).
    pub not_modified: usize,
    /// Sources counted as failed.
    pub failed: usize,
    /// Number of domains in the live blocklist once the cycle was installed.
    pub total_domains: usize,
}
#[derive(Debug, thiserror::Error)]
pub enum RefreshError {
#[error("failed to list enabled blocklist sources: {0}")]
Storage(#[from] crate::storage::Error),
}
/// Cloneable handle that asks a running scheduler for an immediate refresh.
///
/// Backed by a shared `tokio::sync::Notify`; per tokio's `notify_one`
/// semantics, triggers fired while nobody is waiting leave at most one stored
/// permit, so bursts of triggers coalesce into a single refresh.
#[derive(Clone, Debug)]
pub struct RefreshTrigger(Arc<Notify>);

impl RefreshTrigger {
    /// Requests an on-demand refresh.
    pub fn trigger(&self) {
        let Self(notify) = self;
        notify.notify_one();
    }
}
/// Downloads enabled blocklist sources, merges them, and installs the result
/// into the resolver's live blocklist, periodically or on demand.
pub struct BlocklistScheduler {
    /// Source list, cached bodies, and per-source refresh metadata.
    repo: SqliteBlocklistRepo,
    /// Shared resolver state: owns the live blocklist and the settings.
    state: Arc<ResolverState>,
    /// HTTP client used for (conditional) downloads.
    fetcher: Fetcher,
    /// Wake-up signal shared with every handed-out `RefreshTrigger`.
    notify: Arc<Notify>,
}
impl BlocklistScheduler {
pub fn new(repo: SqliteBlocklistRepo, state: Arc<ResolverState>, fetcher: Fetcher) -> Self {
Self {
repo,
state,
fetcher,
notify: Arc::new(Notify::new()),
}
}
pub fn trigger(&self) -> RefreshTrigger {
RefreshTrigger(Arc::clone(&self.notify))
}
fn decode_and_add(
aggregator: &mut Aggregator<i64>,
source_id: i64,
format: crate::storage::blocklists::BlocklistFormat,
content: &[u8],
) {
let text = String::from_utf8_lossy(content);
let names = Parser::from(format).parse(&text);
aggregator.add(source_id, names);
}
pub async fn load_from_cache(&self) {
let sources = match self.repo.list_enabled().await {
Ok(s) => s,
Err(e) => {
warn!(error = %e, "offline cache load: failed to list sources, skipping");
return;
}
};
let mut aggregator: Aggregator<i64> = Aggregator::new();
let mut loaded = 0usize;
for source in &sources {
let cached = match self.repo.load_cache(source.id).await {
Ok(Some(c)) => c,
Ok(None) => continue,
Err(e) => {
warn!(
id = source.id,
url = %source.url,
error = %e,
"offline cache load: failed to read cache, skipping source"
);
continue;
}
};
Self::decode_and_add(&mut aggregator, source.id, source.format, &cached.content);
loaded += 1;
}
let total = aggregator.len();
let _ = aggregator.install(self.state.blocklist());
info!(
sources_loaded = loaded,
total_domains = total,
"offline cache load complete"
);
}
pub async fn refresh_once(&self) -> Result<RefreshSummary, RefreshError> {
let sources = self.repo.list_enabled().await?;
let mut aggregator: Aggregator<i64> = Aggregator::new();
let mut summary = RefreshSummary::default();
for source in &sources {
let validators = Validators {
etag: source.etag.clone(),
last_modified: source.last_modified.clone(),
};
match self.fetcher.fetch(&source.url, &validators).await {
Ok(FetchOutcome::Modified {
body,
validators: new_validators,
}) => {
let text = String::from_utf8_lossy(&body);
let names = Parser::from(source.format).parse(&text);
let count = names.len();
aggregator.add(source.id, names);
if let Err(e) = self.repo.save_cache(source.id, &body).await {
warn!(
id = source.id,
url = %source.url,
error = %e,
"refresh: failed to save cache (continuing)"
);
}
let meta = RefreshMetadata {
entry_count: count as u64,
last_updated: Clock::now_secs(),
etag: new_validators.etag,
last_modified: new_validators.last_modified,
};
if let Err(e) = self.repo.update_refresh_metadata(source.id, &meta).await {
warn!(
id = source.id,
url = %source.url,
error = %e,
"refresh: failed to update metadata (continuing)"
);
}
summary.fetched += 1;
info!(
id = source.id,
url = %source.url,
domains = count,
"refresh: source updated (200)"
);
}
Ok(FetchOutcome::NotModified) => {
match self.repo.load_cache(source.id).await {
Ok(Some(cached)) => {
Self::decode_and_add(
&mut aggregator,
source.id,
source.format,
&cached.content,
);
}
Ok(None) => {
warn!(
id = source.id,
url = %source.url,
"refresh: 304 but no cached content found — source skipped"
);
}
Err(e) => {
warn!(
id = source.id,
url = %source.url,
error = %e,
"refresh: 304 but cache read failed — source skipped"
);
}
}
summary.not_modified += 1;
info!(
id = source.id,
url = %source.url,
"refresh: source not modified (304)"
);
}
Err(e) => {
warn!(
id = source.id,
url = %source.url,
error = %e,
"refresh: fetch failed, falling back to cached content"
);
match self.repo.load_cache(source.id).await {
Ok(Some(cached)) => {
Self::decode_and_add(
&mut aggregator,
source.id,
source.format,
&cached.content,
);
}
Ok(None) => {
warn!(
id = source.id,
url = %source.url,
"refresh: fetch failed and no cached content — source dropped"
);
}
Err(ce) => {
warn!(
id = source.id,
url = %source.url,
cache_error = %ce,
"refresh: fetch failed and cache read also failed — source dropped"
);
}
}
summary.failed += 1;
}
}
}
let contributions = aggregator.install(self.state.blocklist());
summary.total_domains = self.state.blocklist().len();
info!(
fetched = summary.fetched,
not_modified = summary.not_modified,
failed = summary.failed,
total_domains = summary.total_domains,
sources = contributions.len(),
"refresh cycle complete"
);
Ok(summary)
}
pub async fn run(self, token: CancellationToken) {
match self.refresh_once().await {
Ok(summary) => {
info!(
fetched = summary.fetched,
not_modified = summary.not_modified,
failed = summary.failed,
total_domains = summary.total_domains,
"startup refresh complete"
);
}
Err(e) => {
warn!(error = %e, "startup refresh failed — will retry at next interval");
}
}
loop {
let interval_secs = self.state.settings().blocklist_refresh_interval;
let interval = Duration::from_secs(u64::from(interval_secs)).max(MIN_INTERVAL);
tokio::select! {
() = tokio::time::sleep(interval) => {
match self.refresh_once().await {
Ok(summary) => {
info!(
fetched = summary.fetched,
not_modified = summary.not_modified,
failed = summary.failed,
total_domains = summary.total_domains,
"periodic refresh complete"
);
}
Err(e) => {
warn!(error = %e, "periodic refresh failed — will retry at next interval");
}
}
}
() = self.notify.notified() => {
match self.refresh_once().await {
Ok(summary) => {
info!(
fetched = summary.fetched,
not_modified = summary.not_modified,
failed = summary.failed,
total_domains = summary.total_domains,
"on-demand refresh complete"
);
}
Err(e) => {
warn!(error = %e, "on-demand refresh failed");
}
}
}
() = token.cancelled() => {
info!("blocklist scheduler shutting down");
break;
}
}
}
}
}
// Tests for `BlocklistScheduler`. HTTP sources are served by `wiremock` mock
// servers, and each test uses its own SQLite database in a temporary directory.
#[cfg(test)]
mod tests {
use std::time::Duration;
use tempfile::TempDir;
use wiremock::matchers::{header, method, path};
use wiremock::{Mock, MockServer, ResponseTemplate};
use super::*;
use crate::{
blocklist::fetch::Fetcher,
storage::{
Db,
blocklists::{BlocklistFormat, BlocklistRepository, NewBlocklist, SqliteBlocklistRepo},
},
};
/// Opens a fresh database file inside a new temporary directory.
///
/// The `TempDir` is returned so the caller keeps it alive for the whole test.
async fn open_db() -> (TempDir, Db) {
let dir = TempDir::new().expect("temp dir");
let path = dir.path().join("test.db");
let db = Db::connect(&path).await.expect("connect");
(dir, db)
}
/// Builds a scheduler over `db` whose fetcher uses a 5-second timeout.
fn make_scheduler(db: &Db, state: Arc<ResolverState>) -> BlocklistScheduler {
BlocklistScheduler::new(
SqliteBlocklistRepo::new(db.pool().clone()),
state,
Fetcher::new().with_timeout(Duration::from_secs(5)),
)
}
/// An enabled, hosts-format source pointing at `url`.
fn hosts_source(url: &str) -> NewBlocklist {
NewBlocklist {
url: url.to_owned(),
format: BlocklistFormat::Hosts,
enabled: true,
}
}
/// `load_from_cache` installs domains from a seeded cache only; the source
/// URL is never contacted because that method does not fetch.
#[tokio::test]
async fn load_from_cache_builds_set_without_network() {
let (_dir, db) = open_db().await;
let state = ResolverState::hydrate(&db).await.expect("hydrate");
let repo = SqliteBlocklistRepo::new(db.pool().clone());
let src = repo
.insert(hosts_source("https://offline.example.com/hosts"))
.await
.expect("insert");
let body = b"0.0.0.0 ads.example.com\n0.0.0.0 tracker.example.org\n";
repo.save_cache(src.id, body).await.expect("save_cache");
assert!(state.blocklist().is_empty());
let scheduler = make_scheduler(&db, Arc::clone(&state));
scheduler.load_from_cache().await;
let ads: crate::codec::name::Name = "ads.example.com".parse().unwrap();
let tracker: crate::codec::name::Name = "tracker.example.org".parse().unwrap();
assert!(
state.blocklist().contains(&ads),
"ads.example.com must be blocked after cache load"
);
assert!(
state.blocklist().contains(&tracker),
"tracker.example.org must be blocked after cache load"
);
assert_eq!(state.blocklist().len(), 2);
}
/// A 200 response installs the parsed domains, writes the cache, and records
/// entry count, last-updated time and the response ETag.
#[tokio::test]
async fn refresh_once_200_installs_domains_and_persists_metadata() {
let server = MockServer::start().await;
Mock::given(method("GET"))
.and(path("/hosts.txt"))
.respond_with(
ResponseTemplate::new(200)
.set_body_bytes(
b"0.0.0.0 ads.example.com\n0.0.0.0 tracker.example.org\n".to_vec(),
)
.insert_header("etag", r#""v1""#),
)
.mount(&server)
.await;
let url = format!("{}/hosts.txt", server.uri());
let (_dir, db) = open_db().await;
let state = ResolverState::hydrate(&db).await.expect("hydrate");
let repo = SqliteBlocklistRepo::new(db.pool().clone());
let src = repo.insert(hosts_source(&url)).await.expect("insert");
let scheduler = make_scheduler(&db, Arc::clone(&state));
let summary = scheduler.refresh_once().await.expect("refresh_once");
assert_eq!(summary.fetched, 1);
assert_eq!(summary.not_modified, 0);
assert_eq!(summary.failed, 0);
assert_eq!(summary.total_domains, 2);
let ads: crate::codec::name::Name = "ads.example.com".parse().unwrap();
let tracker: crate::codec::name::Name = "tracker.example.org".parse().unwrap();
assert!(state.blocklist().contains(&ads));
assert!(state.blocklist().contains(&tracker));
let cached = repo
.load_cache(src.id)
.await
.expect("load_cache")
.expect("cache must be Some after refresh");
assert!(!cached.content.is_empty());
let rows = repo.list().await.expect("list");
let row = rows.iter().find(|r| r.id == src.id).expect("row");
assert_eq!(row.entry_count, 2);
assert!(row.last_updated.is_some(), "last_updated must be set");
assert_eq!(row.etag.as_deref(), Some(r#""v1""#));
}
/// A 304 re-installs the domains from the seeded cache. The mock answers 304
/// only when `If-None-Match` carries the stored ETag. Any other request gets
/// wiremock's unmatched-request response, which would be counted as `failed`
/// rather than `not_modified`.
#[tokio::test]
async fn refresh_once_304_retains_cached_domains() {
let server = MockServer::start().await;
Mock::given(method("GET"))
.and(path("/hosts.txt"))
.and(header(
reqwest::header::IF_NONE_MATCH.as_str(),
r#""etag-v1""#,
))
.respond_with(ResponseTemplate::new(304))
.mount(&server)
.await;
let url = format!("{}/hosts.txt", server.uri());
let (_dir, db) = open_db().await;
let state = ResolverState::hydrate(&db).await.expect("hydrate");
let repo = SqliteBlocklistRepo::new(db.pool().clone());
let src = repo
.insert(NewBlocklist {
url,
format: BlocklistFormat::Hosts,
enabled: true,
})
.await
.expect("insert");
// Seed the ETag the mock expects, plus the cached body a 304 should reuse.
repo.update_refresh_metadata(
src.id,
&RefreshMetadata {
entry_count: 2,
last_updated: 1_700_000_000,
etag: Some(r#""etag-v1""#.to_owned()),
last_modified: None,
},
)
.await
.expect("update meta");
let body = b"0.0.0.0 ads.example.com\n0.0.0.0 tracker.example.org\n";
repo.save_cache(src.id, body).await.expect("save_cache");
let scheduler = make_scheduler(&db, Arc::clone(&state));
let summary = scheduler.refresh_once().await.expect("refresh_once");
assert_eq!(summary.not_modified, 1);
assert_eq!(summary.fetched, 0);
assert_eq!(summary.failed, 0);
let ads: crate::codec::name::Name = "ads.example.com".parse().unwrap();
let tracker: crate::codec::name::Name = "tracker.example.org".parse().unwrap();
assert!(
state.blocklist().contains(&ads),
"ads must be present after 304"
);
assert!(
state.blocklist().contains(&tracker),
"tracker must be present after 304"
);
}
/// One source returns 200 and another returns 500. The failing source falls
/// back to its seeded cache, and neither outcome affects the other source.
#[tokio::test]
async fn refresh_once_bad_source_retains_cached_domains_and_does_not_sink_good_source() {
let server = MockServer::start().await;
Mock::given(method("GET"))
.and(path("/good.txt"))
.respond_with(
ResponseTemplate::new(200).set_body_bytes(b"0.0.0.0 good.example.com\n".to_vec()),
)
.mount(&server)
.await;
Mock::given(method("GET"))
.and(path("/bad.txt"))
.respond_with(ResponseTemplate::new(500))
.mount(&server)
.await;
let (_dir, db) = open_db().await;
let state = ResolverState::hydrate(&db).await.expect("hydrate");
let repo = SqliteBlocklistRepo::new(db.pool().clone());
repo.insert(NewBlocklist {
url: format!("{}/good.txt", server.uri()),
format: BlocklistFormat::Hosts,
enabled: true,
})
.await
.expect("insert good");
let bad = repo
.insert(NewBlocklist {
url: format!("{}/bad.txt", server.uri()),
format: BlocklistFormat::DomainList,
enabled: true,
})
.await
.expect("insert bad");
repo.save_cache(bad.id, b"bad-but-cached.example.com\n")
.await
.expect("seed bad cache");
let scheduler = make_scheduler(&db, Arc::clone(&state));
let summary = scheduler.refresh_once().await.expect("refresh_once");
assert_eq!(summary.fetched, 1, "good source must count as fetched");
assert_eq!(summary.failed, 1, "bad source must count as failed");
let good: crate::codec::name::Name = "good.example.com".parse().unwrap();
let cached: crate::codec::name::Name = "bad-but-cached.example.com".parse().unwrap();
assert!(
state.blocklist().contains(&good),
"good domain must be present"
);
assert!(
state.blocklist().contains(&cached),
"bad source's cached domain must be retained"
);
assert_eq!(state.blocklist().len(), 2, "exactly 2 domains in live set");
}
/// `run` performs a startup refresh, which installs v1. The test then mounts a
/// second mock serving v2 with priority 1, which takes precedence over the
/// default-priority first mock, and fires the trigger.
///
/// The periodic interval is at least `MIN_INTERVAL` (60s), far beyond the 10s
/// deadlines, so v2 can only appear via the on-demand trigger. Finally the
/// token is cancelled and the task must exit within 5s.
#[tokio::test]
async fn run_on_demand_trigger_forces_refresh() {
let server = MockServer::start().await;
Mock::given(method("GET"))
.and(path("/hosts.txt"))
.respond_with(
ResponseTemplate::new(200).set_body_bytes(b"0.0.0.0 v1.example.com\n".to_vec()),
)
.mount(&server)
.await;
let url = format!("{}/hosts.txt", server.uri());
let (_dir, db) = open_db().await;
let state = ResolverState::hydrate(&db).await.expect("hydrate");
let repo = SqliteBlocklistRepo::new(db.pool().clone());
repo.insert(hosts_source(&url)).await.expect("insert");
let scheduler = BlocklistScheduler::new(
SqliteBlocklistRepo::new(db.pool().clone()),
Arc::clone(&state),
Fetcher::new().with_timeout(Duration::from_secs(5)),
);
let trigger = scheduler.trigger();
let token = CancellationToken::new();
let token_clone = token.clone();
let task = tokio::spawn(async move {
scheduler.run(token_clone).await;
});
let v1: crate::codec::name::Name = "v1.example.com".parse().unwrap();
let deadline = tokio::time::Instant::now() + Duration::from_secs(10);
// Poll until the startup refresh has installed v1.
loop {
if state.blocklist().contains(&v1) {
break;
}
assert!(
tokio::time::Instant::now() < deadline,
"timed out waiting for v1 to appear in blocklist"
);
tokio::time::sleep(Duration::from_millis(10)).await;
}
Mock::given(method("GET"))
.and(path("/hosts.txt"))
.respond_with(
ResponseTemplate::new(200).set_body_bytes(b"0.0.0.0 v2.example.com\n".to_vec()),
)
.with_priority(1)
.mount(&server)
.await;
trigger.trigger();
let v2: crate::codec::name::Name = "v2.example.com".parse().unwrap();
let deadline2 = tokio::time::Instant::now() + Duration::from_secs(10);
// Poll until the on-demand refresh has installed v2.
loop {
if state.blocklist().contains(&v2) {
break;
}
assert!(
tokio::time::Instant::now() < deadline2,
"timed out waiting for v2 to appear in blocklist after trigger"
);
tokio::time::sleep(Duration::from_millis(10)).await;
}
token.cancel();
tokio::time::timeout(Duration::from_secs(5), task)
.await
.expect("scheduler task timed out on shutdown")
.expect("scheduler task panicked");
}
/// With no sources configured, both `load_from_cache` and `refresh_once`
/// succeed and leave the blocklist empty.
#[tokio::test]
async fn zero_sources_load_and_refresh_no_panic() {
let (_dir, db) = open_db().await;
let state = ResolverState::hydrate(&db).await.expect("hydrate");
let scheduler = make_scheduler(&db, Arc::clone(&state));
scheduler.load_from_cache().await;
assert!(state.blocklist().is_empty());
let summary = scheduler.refresh_once().await.expect("refresh_once");
assert_eq!(summary.fetched, 0);
assert_eq!(summary.not_modified, 0);
assert_eq!(summary.failed, 0);
assert_eq!(summary.total_domains, 0);
assert!(state.blocklist().is_empty());
}
}