use std::{sync::Arc, time::Duration};
use tokio::sync::Notify;
use tokio_util::sync::CancellationToken;
use tracing::{info, warn};
use crate::{
blocklist::{
aggregate::Aggregator,
fetch::{FetchOutcome, Fetcher, Validators},
parse::{BlocklistParser as _, Parser},
},
resolver::state::ResolverState,
storage::blocklists::{Blocklist, BlocklistRepository, RefreshMetadata, SqliteBlocklistRepo},
time::Clock,
};
/// Lower bound for the pause between periodic refreshes: `run` raises the
/// configured refresh interval to at least this value before sleeping.
const MIN_INTERVAL: Duration = Duration::from_secs(60);
/// Result of processing a single source during a refresh cycle.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum SourceRefresh {
/// The source's entries (freshly fetched or read from cache) were handed to
/// the aggregator.
Contributed,
/// No entries could be supplied for the source (no cached content, or the
/// cache read failed); `refresh_once` keeps the existing live blocklist when
/// any source ends up here.
Incomplete,
}
/// Counters describing one refresh cycle, returned by `refresh_once`.
#[derive(Debug, Clone, PartialEq, Eq, Default)]
pub struct RefreshSummary {
/// Sources whose fetch returned a new body that parsed to at least one domain.
pub fetched: usize,
/// Sources whose fetch reported "not modified".
pub not_modified: usize,
/// Sources that hit a fetch error, an empty body, or a "not modified"
/// response without usable cached content.
pub failed: usize,
/// Size of the live blocklist once the cycle has finished.
pub total_domains: usize,
}
/// Errors that abort a whole refresh cycle.
#[derive(Debug, thiserror::Error)]
pub enum RefreshError {
/// The storage layer failed; `refresh_once` produces this when listing the
/// enabled sources does not succeed.
#[error("failed to list enabled blocklist sources: {0}")]
Storage(#[from] crate::storage::Error),
}
/// Cloneable handle used to request an out-of-schedule refresh.
///
/// It wraps the same [`Notify`] that the scheduler's `run` loop waits on.
#[derive(Debug, Clone)]
pub struct RefreshTrigger(Arc<Notify>);

impl RefreshTrigger {
    /// Signals the scheduler by notifying the shared [`Notify`] once.
    pub fn trigger(&self) {
        let Self(notifier) = self;
        notifier.notify_one();
    }
}
/// Keeps the resolver's live blocklist in sync with the enabled sources
/// stored in the repository.
pub struct BlocklistScheduler {
/// Storage for source rows, cached bodies and refresh metadata.
repo: SqliteBlocklistRepo,
/// Shared resolver state; supplies the live blocklist and the
/// refresh-interval setting.
state: Arc<ResolverState>,
/// Client used to download source bodies with conditional validators.
fetcher: Fetcher,
/// Notifier shared with every [`RefreshTrigger`] handed out by `trigger`.
notify: Arc<Notify>,
}
impl BlocklistScheduler {
/// Creates a scheduler over the given repository, resolver state and
/// fetcher, together with a fresh notifier for on-demand refreshes.
#[must_use]
pub fn new(repo: SqliteBlocklistRepo, state: Arc<ResolverState>, fetcher: Fetcher) -> Self {
    let notify = Arc::new(Notify::new());
    Self {
        repo,
        state,
        fetcher,
        notify,
    }
}
/// Returns a handle that shares this scheduler's notifier, so callers can
/// request a refresh while `run` owns the scheduler.
#[must_use]
pub fn trigger(&self) -> RefreshTrigger {
    let shared = Arc::clone(&self.notify);
    RefreshTrigger(shared)
}
/// Decodes `content` as (lossy) UTF-8, parses it with the parser selected by
/// `format`, and registers the resulting names under `source_id`.
fn decode_and_add(
    aggregator: &mut Aggregator,
    source_id: i64,
    format: crate::storage::blocklists::BlocklistFormat,
    content: &[u8],
) {
    // Invalid UTF-8 sequences are replaced rather than rejected.
    let decoded = String::from_utf8_lossy(content);
    aggregator.add(source_id, Parser::from(format).parse(&decoded));
}
/// Fetches one source using its stored validators and dispatches to the
/// handler matching the outcome (new body, not modified, or fetch error).
///
/// Returns whether the source contributed entries to `aggregator`.
async fn refresh_source(
    &self,
    source: &Blocklist,
    aggregator: &mut Aggregator,
    summary: &mut RefreshSummary,
) -> SourceRefresh {
    let stored = Validators {
        etag: source.etag.clone(),
        last_modified: source.last_modified.clone(),
    };
    let outcome = self.fetcher.fetch(&source.url, &stored).await;
    match outcome {
        Err(e) => {
            self.handle_fetch_error(source, &e, aggregator, summary)
                .await
        }
        Ok(FetchOutcome::NotModified) => {
            self.handle_not_modified(source, aggregator, summary).await
        }
        Ok(FetchOutcome::Modified { body, validators }) => {
            self.handle_modified(source, body, validators, aggregator, summary)
                .await
        }
    }
}
/// Handles a fetch that returned a new body.
///
/// The body is parsed with the source's format. If it yields zero domains it
/// is treated as a soft failure: `summary.failed` is bumped and the last-good
/// cached content is used instead. Otherwise the parsed names are added to
/// `aggregator`, the body is written to the cache and the refresh metadata
/// is updated; both writes are best-effort and only logged on failure.
///
/// The validators returned by the server are persisted only when the body
/// was cached successfully. If the cache write failed, the source's previous
/// validators are kept, so the next cycle downloads the body again instead
/// of receiving a "not modified" answer for content that is not in the cache.
async fn handle_modified(
    &self,
    source: &Blocklist,
    body: bytes::Bytes,
    validators: Validators,
    aggregator: &mut Aggregator,
    summary: &mut RefreshSummary,
) -> SourceRefresh {
    let text = String::from_utf8_lossy(&body);
    let names = Parser::from(source.format).parse(&text);
    let count = names.len();
    if count == 0 {
        warn!(
            id = source.id,
            url = %source.url,
            "refresh: 200 but parsed zero domains; keeping last-good cache"
        );
        summary.failed += 1;
        return self
            .add_cached_source(source, aggregator, "empty 200 body")
            .await;
    }
    aggregator.add(source.id, names);

    let cache_saved = match self.repo.save_cache(source.id, &body).await {
        Ok(_) => true,
        Err(e) => {
            warn!(
                id = source.id,
                url = %source.url,
                error = %e,
                "refresh: failed to save cache (continuing)"
            );
            false
        }
    };

    // Validators must describe the body that is actually in the cache;
    // otherwise a later "not modified" response would resurrect stale (or
    // missing) cached content.
    let (etag, last_modified) = if cache_saved {
        (validators.etag, validators.last_modified)
    } else {
        (source.etag.clone(), source.last_modified.clone())
    };
    let meta = RefreshMetadata {
        entry_count: count as u64,
        last_updated: Clock::now_secs(),
        etag,
        last_modified,
    };
    if let Err(e) = self.repo.update_refresh_metadata(source.id, &meta).await {
        warn!(
            id = source.id,
            url = %source.url,
            error = %e,
            "refresh: failed to update metadata (continuing)"
        );
    }

    summary.fetched += 1;
    info!(
        id = source.id,
        url = %source.url,
        domains = count,
        "refresh: source updated (200)"
    );
    SourceRefresh::Contributed
}
/// Handles a "not modified" fetch by re-using the cached content.
///
/// Always counts the source as `not_modified`; additionally counts it as
/// `failed` when no cached content could be added.
async fn handle_not_modified(
    &self,
    source: &Blocklist,
    aggregator: &mut Aggregator,
    summary: &mut RefreshSummary,
) -> SourceRefresh {
    let outcome = self.add_cached_source(source, aggregator, "304").await;
    summary.not_modified += 1;
    match outcome {
        SourceRefresh::Incomplete => summary.failed += 1,
        SourceRefresh::Contributed => {}
    }
    info!(
        id = source.id,
        url = %source.url,
        "refresh: source not modified (304)"
    );
    outcome
}
/// Handles a failed fetch: logs the error, counts the source as `failed`
/// and falls back to whatever content is cached for it.
async fn handle_fetch_error(
    &self,
    source: &Blocklist,
    error: &crate::blocklist::fetch::FetchError,
    aggregator: &mut Aggregator,
    summary: &mut RefreshSummary,
) -> SourceRefresh {
    warn!(
        id = source.id,
        url = %source.url,
        error = %error,
        "refresh: fetch failed, falling back to cached content"
    );
    summary.failed += 1;
    let fallback = self
        .add_cached_source(source, aggregator, "fetch failed")
        .await;
    fallback
}
/// Adds the source's cached content to `aggregator`.
///
/// `context` names the situation that led to the fallback and is only used
/// in log messages. Returns `Incomplete` when there is no cached content or
/// the cache cannot be read.
async fn add_cached_source(
    &self,
    source: &Blocklist,
    aggregator: &mut Aggregator,
    context: &'static str,
) -> SourceRefresh {
    let cached = match self.repo.load_cache(source.id).await {
        Ok(Some(entry)) => entry,
        Ok(None) => {
            warn!(
                id = source.id,
                url = %source.url,
                "refresh: {context} but no cached content — source skipped"
            );
            return SourceRefresh::Incomplete;
        }
        Err(e) => {
            warn!(
                id = source.id,
                url = %source.url,
                error = %e,
                "refresh: {context} but cache read failed — source skipped"
            );
            return SourceRefresh::Incomplete;
        }
    };
    Self::decode_and_add(aggregator, source.id, source.format, &cached.content);
    SourceRefresh::Contributed
}
/// Builds the live blocklist from cached source bodies only.
///
/// Sources without cached content, or whose cache cannot be read, are
/// skipped. If the enabled sources cannot be listed the live blocklist is
/// left untouched. Failures are logged, never returned.
pub async fn load_from_cache(&self) {
    let sources = match self.repo.list_enabled().await {
        Err(e) => {
            warn!(error = %e, "offline cache load: failed to list sources, skipping");
            return;
        }
        Ok(enabled) => enabled,
    };

    let mut aggregator: Aggregator = Aggregator::new();
    let mut loaded = 0usize;
    for source in &sources {
        match self.repo.load_cache(source.id).await {
            Ok(Some(cached)) => {
                Self::decode_and_add(
                    &mut aggregator,
                    source.id,
                    source.format,
                    &cached.content,
                );
                loaded += 1;
            }
            // Nothing cached yet for this source.
            Ok(None) => {}
            Err(e) => {
                warn!(
                    id = source.id,
                    url = %source.url,
                    error = %e,
                    "offline cache load: failed to read cache, skipping source"
                );
            }
        }
    }

    // Capture the size before the aggregator is consumed by `install`.
    let total = aggregator.len();
    let _ = aggregator.install(self.state.blocklist());
    info!(
        sources_loaded = loaded,
        total_domains = total,
        "offline cache load complete"
    );
}
/// Runs one refresh cycle over every enabled source.
///
/// The aggregated result replaces the live blocklist only if every source
/// contributed; otherwise the existing live blocklist is kept.
///
/// # Errors
///
/// Returns [`RefreshError::Storage`] when the enabled sources cannot be
/// listed.
pub async fn refresh_once(&self) -> Result<RefreshSummary, RefreshError> {
    let sources = self.repo.list_enabled().await?;
    let mut aggregator: Aggregator = Aggregator::new();
    let mut summary = RefreshSummary::default();

    let mut incomplete = false;
    for source in &sources {
        let outcome = self
            .refresh_source(source, &mut aggregator, &mut summary)
            .await;
        incomplete |= outcome == SourceRefresh::Incomplete;
    }

    let contributions = if incomplete {
        warn!("refresh: incomplete snapshot; keeping existing live blocklist");
        Vec::new()
    } else {
        aggregator.install(self.state.blocklist())
    };

    // Report whatever is live now, whether or not a new snapshot went in.
    summary.total_domains = self.state.blocklist().len();
    info!(
        fetched = summary.fetched,
        not_modified = summary.not_modified,
        failed = summary.failed,
        total_domains = summary.total_domains,
        sources = contributions.len(),
        "refresh cycle complete"
    );
    Ok(summary)
}
pub async fn run(self, token: CancellationToken) {
match self.refresh_once().await {
Ok(summary) => {
info!(
fetched = summary.fetched,
not_modified = summary.not_modified,
failed = summary.failed,
total_domains = summary.total_domains,
"startup refresh complete"
);
}
Err(e) => {
warn!(error = %e, "startup refresh failed — will retry at next interval");
}
}
loop {
let interval_secs = self.state.settings().blocklist_refresh_interval;
let interval = Duration::from_secs(u64::from(interval_secs)).max(MIN_INTERVAL);
tokio::select! {
() = tokio::time::sleep(interval) => {
match self.refresh_once().await {
Ok(summary) => {
info!(
fetched = summary.fetched,
not_modified = summary.not_modified,
failed = summary.failed,
total_domains = summary.total_domains,
"periodic refresh complete"
);
}
Err(e) => {
warn!(error = %e, "periodic refresh failed — will retry at next interval");
}
}
}
() = self.notify.notified() => {
match self.refresh_once().await {
Ok(summary) => {
info!(
fetched = summary.fetched,
not_modified = summary.not_modified,
failed = summary.failed,
total_domains = summary.total_domains,
"on-demand refresh complete"
);
}
Err(e) => {
warn!(error = %e, "on-demand refresh failed");
}
}
}
() = token.cancelled() => {
info!("blocklist scheduler shutting down");
break;
}
}
}
}
}
#[cfg(test)]
mod tests {
use std::time::Duration;
use tempfile::TempDir;
use wiremock::matchers::{header, method, path};
use wiremock::{Mock, MockServer, ResponseTemplate};
use super::*;
use crate::{
blocklist::fetch::Fetcher,
storage::{
Db,
blocklists::{BlocklistFormat, BlocklistRepository, NewBlocklist},
},
};
/// Opens a database file inside a fresh temporary directory.
///
/// The directory handle is returned alongside the database so the caller
/// decides how long the directory stays alive.
async fn open_db() -> (TempDir, Db) {
    let dir = TempDir::new().expect("temp dir");
    let db = Db::connect(&dir.path().join("test.db"))
        .await
        .expect("connect");
    (dir, db)
}
/// Builds a scheduler over `db` and `state` whose fetcher uses a 5 second
/// timeout.
fn make_scheduler(db: &Db, state: Arc<ResolverState>) -> BlocklistScheduler {
    let repo = db.blocklists();
    let fetcher = Fetcher::new().with_timeout(Duration::from_secs(5));
    BlocklistScheduler::new(repo, state, fetcher)
}
/// Describes a new, enabled source in hosts format located at `url`.
fn hosts_source(url: &str) -> NewBlocklist {
    let url = String::from(url);
    NewBlocklist {
        url,
        format: BlocklistFormat::Hosts,
        enabled: true,
    }
}
/// A cached body alone is enough for `load_from_cache` to populate the live
/// blocklist; no mock server is involved.
#[tokio::test]
async fn load_from_cache_builds_set_without_network() {
    let name = |s: &str| s.parse::<crate::codec::name::Name>().unwrap();

    let (_dir, db) = open_db().await;
    let state = ResolverState::hydrate(&db).await.expect("hydrate");
    let repo = db.blocklists();
    let source = repo
        .insert(hosts_source("https://offline.example.com/hosts"))
        .await
        .expect("insert");
    repo.save_cache(
        source.id,
        b"0.0.0.0 ads.example.com\n0.0.0.0 tracker.example.org\n",
    )
    .await
    .expect("save_cache");
    assert!(state.blocklist().is_empty());

    make_scheduler(&db, Arc::clone(&state))
        .load_from_cache()
        .await;

    assert!(
        state.blocklist().contains(&name("ads.example.com")),
        "ads.example.com must be blocked after cache load"
    );
    assert!(
        state.blocklist().contains(&name("tracker.example.org")),
        "tracker.example.org must be blocked after cache load"
    );
    assert_eq!(state.blocklist().len(), 2);
}
/// A 200 response with two hosts entries must be installed, cached, and
/// reflected in the source row's entry count, timestamp and ETag.
#[tokio::test]
async fn refresh_once_200_installs_domains_and_persists_metadata() {
    let name = |s: &str| s.parse::<crate::codec::name::Name>().unwrap();

    let server = MockServer::start().await;
    let response = ResponseTemplate::new(200)
        .set_body_bytes(b"0.0.0.0 ads.example.com\n0.0.0.0 tracker.example.org\n".to_vec())
        .insert_header("etag", r#""v1""#);
    Mock::given(method("GET"))
        .and(path("/hosts.txt"))
        .respond_with(response)
        .mount(&server)
        .await;
    let url = format!("{}/hosts.txt", server.uri());

    let (_dir, db) = open_db().await;
    let state = ResolverState::hydrate(&db).await.expect("hydrate");
    let repo = db.blocklists();
    let source = repo.insert(hosts_source(&url)).await.expect("insert");

    let summary = make_scheduler(&db, Arc::clone(&state))
        .refresh_once()
        .await
        .expect("refresh_once");

    assert_eq!(summary.fetched, 1);
    assert_eq!(summary.not_modified, 0);
    assert_eq!(summary.failed, 0);
    assert_eq!(summary.total_domains, 2);
    assert!(state.blocklist().contains(&name("ads.example.com")));
    assert!(state.blocklist().contains(&name("tracker.example.org")));

    let cached = repo
        .load_cache(source.id)
        .await
        .expect("load_cache")
        .expect("cache must be Some after refresh");
    assert!(!cached.content.is_empty());

    let rows = repo.list().await.expect("list");
    let stored = rows.iter().find(|r| r.id == source.id).expect("row");
    assert_eq!(stored.entry_count, 2);
    assert!(stored.last_updated.is_some(), "last_updated must be set");
    assert_eq!(stored.etag.as_deref(), Some(r#""v1""#));
}
#[tokio::test]
async fn refresh_once_304_retains_cached_domains() {
let server = MockServer::start().await;
Mock::given(method("GET"))
.and(path("/hosts.txt"))
.and(header(
reqwest::header::IF_NONE_MATCH.as_str(),
r#""etag-v1""#,
))
.respond_with(ResponseTemplate::new(304))
.mount(&server)
.await;
let url = format!("{}/hosts.txt", server.uri());
let (_dir, db) = open_db().await;
let state = ResolverState::hydrate(&db).await.expect("hydrate");
let repo = db.blocklists();
let src = repo
.insert(NewBlocklist {
url,
format: BlocklistFormat::Hosts,
enabled: true,
})
.await
.expect("insert");
repo.update_refresh_metadata(
src.id,
&RefreshMetadata {
entry_count: 2,
last_updated: 1_700_000_000,
etag: Some(r#""etag-v1""#.to_owned()),
last_modified: None,
},
)
.await
.expect("update meta");
let body = b"0.0.0.0 ads.example.com\n0.0.0.0 tracker.example.org\n";
repo.save_cache(src.id, body).await.expect("save_cache");
let scheduler = make_scheduler(&db, Arc::clone(&state));
let summary = scheduler.refresh_once().await.expect("refresh_once");
assert_eq!(summary.not_modified, 1);
assert_eq!(summary.fetched, 0);
assert_eq!(summary.failed, 0);
let ads: crate::codec::name::Name = "ads.example.com".parse().unwrap();
let tracker: crate::codec::name::Name = "tracker.example.org".parse().unwrap();
assert!(
state.blocklist().contains(&ads),
"ads must be present after 304"
);
assert!(
state.blocklist().contains(&tracker),
"tracker must be present after 304"
);
}
#[tokio::test]
async fn refresh_once_empty_200_retains_cached_domains() {
let server = MockServer::start().await;
Mock::given(method("GET"))
.and(path("/hosts.txt"))
.respond_with(ResponseTemplate::new(200).set_body_string(""))
.mount(&server)
.await;
let url = format!("{}/hosts.txt", server.uri());
let (_dir, db) = open_db().await;
let state = ResolverState::hydrate(&db).await.expect("hydrate");
let repo = db.blocklists();
let src = repo
.insert(NewBlocklist {
url,
format: BlocklistFormat::Hosts,
enabled: true,
})
.await
.expect("insert");
let body = b"0.0.0.0 ads.example.com\n0.0.0.0 tracker.example.org\n";
repo.save_cache(src.id, body).await.expect("save_cache");
let scheduler = make_scheduler(&db, Arc::clone(&state));
let summary = scheduler.refresh_once().await.expect("refresh_once");
assert_eq!(summary.fetched, 0, "empty body must not count as fetched");
assert_eq!(summary.failed, 1, "empty 200 is a soft failure");
let ads: crate::codec::name::Name = "ads.example.com".parse().unwrap();
let tracker: crate::codec::name::Name = "tracker.example.org".parse().unwrap();
assert!(
state.blocklist().contains(&ads),
"ads must survive an empty 200"
);
assert!(
state.blocklist().contains(&tracker),
"tracker must survive an empty 200"
);
let cached = repo.load_cache(src.id).await.expect("load_cache");
assert_eq!(
cached.map(|c| c.content),
Some(body.to_vec()),
"the empty body must not poison the last-good cache"
);
}
#[tokio::test]
async fn refresh_once_bad_source_retains_cached_domains_and_does_not_sink_good_source() {
let server = MockServer::start().await;
Mock::given(method("GET"))
.and(path("/good.txt"))
.respond_with(
ResponseTemplate::new(200).set_body_bytes(b"0.0.0.0 good.example.com\n".to_vec()),
)
.mount(&server)
.await;
Mock::given(method("GET"))
.and(path("/bad.txt"))
.respond_with(ResponseTemplate::new(500))
.mount(&server)
.await;
let (_dir, db) = open_db().await;
let state = ResolverState::hydrate(&db).await.expect("hydrate");
let repo = db.blocklists();
repo.insert(NewBlocklist {
url: format!("{}/good.txt", server.uri()),
format: BlocklistFormat::Hosts,
enabled: true,
})
.await
.expect("insert good");
let bad = repo
.insert(NewBlocklist {
url: format!("{}/bad.txt", server.uri()),
format: BlocklistFormat::DomainList,
enabled: true,
})
.await
.expect("insert bad");
repo.save_cache(bad.id, b"bad-but-cached.example.com\n")
.await
.expect("seed bad cache");
let scheduler = make_scheduler(&db, Arc::clone(&state));
let summary = scheduler.refresh_once().await.expect("refresh_once");
assert_eq!(summary.fetched, 1, "good source must count as fetched");
assert_eq!(summary.failed, 1, "bad source must count as failed");
let good: crate::codec::name::Name = "good.example.com".parse().unwrap();
let cached: crate::codec::name::Name = "bad-but-cached.example.com".parse().unwrap();
assert!(
state.blocklist().contains(&good),
"good domain must be present"
);
assert!(
state.blocklist().contains(&cached),
"bad source's cached domain must be retained"
);
assert_eq!(state.blocklist().len(), 2, "exactly 2 domains in live set");
}
#[tokio::test]
async fn refresh_once_incomplete_source_keeps_previous_live_snapshot() {
let server = MockServer::start().await;
Mock::given(method("GET"))
.and(path("/good.txt"))
.respond_with(
ResponseTemplate::new(200)
.set_body_bytes(b"0.0.0.0 newly-fetched.example.com\n".to_vec()),
)
.mount(&server)
.await;
Mock::given(method("GET"))
.and(path("/missing.txt"))
.respond_with(ResponseTemplate::new(500))
.mount(&server)
.await;
let (_dir, db) = open_db().await;
let state = ResolverState::hydrate(&db).await.expect("hydrate");
let repo = db.blocklists();
let previous: crate::codec::name::Name = "previous.example.com".parse().unwrap();
state
.blocklist()
.store([(previous.clone(), 1)].into_iter().collect());
repo.insert(NewBlocklist {
url: format!("{}/good.txt", server.uri()),
format: BlocklistFormat::Hosts,
enabled: true,
})
.await
.expect("insert good");
repo.insert(NewBlocklist {
url: format!("{}/missing.txt", server.uri()),
format: BlocklistFormat::DomainList,
enabled: true,
})
.await
.expect("insert missing");
let scheduler = make_scheduler(&db, Arc::clone(&state));
let summary = scheduler.refresh_once().await.expect("refresh_once");
assert_eq!(summary.fetched, 1, "good source still fetched");
assert_eq!(summary.failed, 1, "missing source counts as failed");
assert_eq!(summary.total_domains, 1, "previous snapshot remains live");
let fetched: crate::codec::name::Name = "newly-fetched.example.com".parse().unwrap();
assert!(state.blocklist().contains(&previous));
assert!(
!state.blocklist().contains(&fetched),
"partial refresh result must not be installed"
);
}
#[tokio::test]
async fn run_on_demand_trigger_forces_refresh() {
let server = MockServer::start().await;
Mock::given(method("GET"))
.and(path("/hosts.txt"))
.respond_with(
ResponseTemplate::new(200).set_body_bytes(b"0.0.0.0 v1.example.com\n".to_vec()),
)
.mount(&server)
.await;
let url = format!("{}/hosts.txt", server.uri());
let (_dir, db) = open_db().await;
let state = ResolverState::hydrate(&db).await.expect("hydrate");
let repo = db.blocklists();
repo.insert(hosts_source(&url)).await.expect("insert");
let scheduler = BlocklistScheduler::new(
db.blocklists(),
Arc::clone(&state),
Fetcher::new().with_timeout(Duration::from_secs(5)),
);
let trigger = scheduler.trigger();
let token = CancellationToken::new();
let token_clone = token.clone();
let task = tokio::spawn(async move {
scheduler.run(token_clone).await;
});
let v1: crate::codec::name::Name = "v1.example.com".parse().unwrap();
let deadline = tokio::time::Instant::now() + Duration::from_secs(10);
loop {
if state.blocklist().contains(&v1) {
break;
}
assert!(
tokio::time::Instant::now() < deadline,
"timed out waiting for v1 to appear in blocklist"
);
tokio::time::sleep(Duration::from_millis(10)).await;
}
Mock::given(method("GET"))
.and(path("/hosts.txt"))
.respond_with(
ResponseTemplate::new(200).set_body_bytes(b"0.0.0.0 v2.example.com\n".to_vec()),
)
.with_priority(1)
.mount(&server)
.await;
trigger.trigger();
let v2: crate::codec::name::Name = "v2.example.com".parse().unwrap();
let deadline2 = tokio::time::Instant::now() + Duration::from_secs(10);
loop {
if state.blocklist().contains(&v2) {
break;
}
assert!(
tokio::time::Instant::now() < deadline2,
"timed out waiting for v2 to appear in blocklist after trigger"
);
tokio::time::sleep(Duration::from_millis(10)).await;
}
token.cancel();
tokio::time::timeout(Duration::from_secs(5), task)
.await
.expect("scheduler task timed out on shutdown")
.expect("scheduler task panicked");
}
#[tokio::test]
async fn zero_sources_load_and_refresh_no_panic() {
let (_dir, db) = open_db().await;
let state = ResolverState::hydrate(&db).await.expect("hydrate");
let scheduler = make_scheduler(&db, Arc::clone(&state));
scheduler.load_from_cache().await;
assert!(state.blocklist().is_empty());
let summary = scheduler.refresh_once().await.expect("refresh_once");
assert_eq!(summary.fetched, 0);
assert_eq!(summary.not_modified, 0);
assert_eq!(summary.failed, 0);
assert_eq!(summary.total_domains, 0);
assert!(state.blocklist().is_empty());
}
}