use askama::Template;
use askama_web::WebTemplate;
use axum::{
extract::{Query, State},
response::{IntoResponse, Redirect, Response},
};
use datastar::prelude::PatchElements;
use serde::Deserialize;
use crate::{
codec::name::Name,
storage::lists::{AllowlistRepository, BlacklistRepository},
web::{
AppState,
auth::CurrentUser,
render::{DomainDisplay, WebError, WebResult, datastar_response},
},
};
impl AppState {
    /// Re-reads every blacklist entry from the database and publishes the
    /// collected set to the resolver's in-memory blacklist.
    ///
    /// NOTE(review): nothing visible here serializes the `load_all` / `store`
    /// pair, so two concurrent mutations could publish snapshots out of order
    /// (an older snapshot stored last). Confirm callers are serialized.
    async fn reload_blacklist(&self) -> WebResult<()> {
        let snapshot = self.db.blacklist().load_all().await?;
        self.resolver
            .blacklist()
            .store(snapshot.into_iter().collect());
        Ok(())
    }

    /// Re-reads every allowlist entry from the database and publishes the
    /// collected set to the resolver's in-memory allowlist.
    ///
    /// NOTE(review): same unserialized load/store caveat as the blacklist.
    async fn reload_allowlist(&self) -> WebResult<()> {
        let snapshot = self.db.allowlist().load_all().await?;
        self.resolver
            .allowlist()
            .store(snapshot.into_iter().collect());
        Ok(())
    }

    /// Persists `domain` to the blacklist, then refreshes the resolver copy.
    ///
    /// # Errors
    /// Propagates storage errors (e.g. `WebError::BadRequest` for a domain the
    /// repository rejects) and reload failures.
    pub(crate) async fn add_to_blacklist(&self, domain: &str) -> WebResult<()> {
        self.db.blacklist().add(domain).await?;
        self.reload_blacklist().await?;
        Ok(())
    }

    /// Deletes `domain` from the blacklist, then refreshes the resolver copy.
    ///
    /// # Errors
    /// Propagates storage and reload errors.
    pub(crate) async fn remove_from_blacklist(&self, domain: &str) -> WebResult<()> {
        self.db.blacklist().remove(domain).await?;
        self.reload_blacklist().await?;
        Ok(())
    }

    /// Persists `domain` to the allowlist, then refreshes the resolver copy.
    ///
    /// # Errors
    /// Propagates storage and reload errors.
    pub(crate) async fn add_to_allowlist(&self, domain: &str) -> WebResult<()> {
        self.db.allowlist().add(domain).await?;
        self.reload_allowlist().await?;
        Ok(())
    }

    /// Deletes `domain` from the allowlist, then refreshes the resolver copy.
    ///
    /// # Errors
    /// Propagates storage and reload errors.
    pub(crate) async fn remove_from_allowlist(&self, domain: &str) -> WebResult<()> {
        self.db.allowlist().remove(domain).await?;
        self.reload_allowlist().await?;
        Ok(())
    }

    /// Makes `domain` resolvable again using the least invasive change.
    ///
    /// There are two cases:
    /// - If `domain` parses as a [`Name`] and is in the resolver's blacklist,
    ///   it is removed from the blacklist.
    /// - Otherwise, including when it does not parse, it is added to the
    ///   allowlist.
    pub(crate) async fn unblock(&self, domain: &str) -> WebResult<()> {
        let blacklisted = match domain.parse::<Name>() {
            Ok(parsed) => self.resolver.blacklist().contains(&parsed),
            Err(_) => false,
        };
        if blacklisted {
            return self.remove_from_blacklist(domain).await;
        }
        self.add_to_allowlist(domain).await
    }

    /// Query-log "block" action.
    ///
    /// Blacklists `?domain=` and replies with a toast patch.
    pub async fn log_block(
        _user: CurrentUser,
        State(state): State<AppState>,
        Query(q): Query<ActionQuery>,
    ) -> WebResult<Response> {
        state.add_to_blacklist(&q.domain).await?;
        let shown = q.domain.display_domain();
        Ok(toast(format!(
            "Blocked {shown} — it is blocked on the next query."
        )))
    }

    /// Query-log "unblock" action.
    ///
    /// Runs [`AppState::unblock`] for `?domain=` and replies with a toast patch.
    ///
    /// NOTE(review): the message says the domain resolves on the next query.
    /// If the domain was only removed from the blacklist but also matches a
    /// blocklist, it may still be blocked. Confirm this is acceptable.
    pub async fn log_unblock(
        _user: CurrentUser,
        State(state): State<AppState>,
        Query(q): Query<ActionQuery>,
    ) -> WebResult<Response> {
        state.unblock(&q.domain).await?;
        let shown = q.domain.display_domain();
        Ok(toast(format!(
            "Unblocked {shown} — it resolves on the next query."
        )))
    }
}
/// Which admin-managed domain list a page or handler operates on.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum ListKind {
    /// Domains the admin explicitly blocks.
    Blacklist,
    /// Domains the admin explicitly allows.
    Allowlist,
}
impl ListKind {
fn title(self) -> &'static str {
match self {
Self::Blacklist => "Blacklist",
Self::Allowlist => "Allowlist",
}
}
fn active(self) -> &'static str {
match self {
Self::Blacklist => "blacklist",
Self::Allowlist => "allowlist",
}
}
fn base_path(self) -> &'static str {
match self {
Self::Blacklist => "/blacklist",
Self::Allowlist => "/allowlist",
}
}
fn description(self) -> &'static str {
match self {
Self::Blacklist => {
"Domains you explicitly block. Highest precedence — wins over the allowlist and blocklists."
}
Self::Allowlist => {
"Domains you explicitly allow. An exception that suppresses blocklist matches (but not the blacklist)."
}
}
}
}
impl AppState {
    /// Loads every entry of `kind` from the database.
    ///
    /// Returns each domain in display form, in the order the repository
    /// yields them.
    ///
    /// # Errors
    /// Propagates storage errors converted into [`WebError`].
    async fn list_entries(&self, kind: ListKind) -> WebResult<Vec<String>> {
        let entries = match kind {
            ListKind::Blacklist => self.db.blacklist().list().await?,
            ListKind::Allowlist => self.db.allowlist().list().await?,
        };
        Ok(entries
            .into_iter()
            .map(|e| e.domain.display_domain().to_owned())
            .collect())
    }

    /// Adds `domain` to the list selected by `kind`.
    ///
    /// Writes through to the database and refreshes the resolver copy.
    async fn list_add(&self, kind: ListKind, domain: &str) -> WebResult<()> {
        match kind {
            ListKind::Blacklist => self.add_to_blacklist(domain).await,
            ListKind::Allowlist => self.add_to_allowlist(domain).await,
        }
    }

    /// Removes `domain` from the list selected by `kind`.
    ///
    /// Writes through to the database and refreshes the resolver copy.
    async fn list_remove(&self, kind: ListKind, domain: &str) -> WebResult<()> {
        match kind {
            ListKind::Blacklist => self.remove_from_blacklist(domain).await,
            ListKind::Allowlist => self.remove_from_allowlist(domain).await,
        }
    }

    /// Builds the list page template for `kind`.
    ///
    /// `error` is an optional validation message to show on the page.
    async fn render_list(
        &self,
        user: &CurrentUser,
        kind: ListKind,
        error: Option<String>,
    ) -> WebResult<ListPageTemplate> {
        Ok(ListPageTemplate {
            chrome: self.chrome(kind.active(), user).await,
            title: kind.title(),
            description: kind.description(),
            base_path: kind.base_path(),
            entries: self.list_entries(kind).await?,
            error,
        })
    }

    /// Renders the list page for `kind` with no error message.
    async fn list_page(&self, user: &CurrentUser, kind: ListKind) -> WebResult<Response> {
        Ok(self.render_list(user, kind, None).await?.into_response())
    }

    /// Handles an add-form submission.
    ///
    /// On success, redirects back to the list page. On
    /// `WebError::BadRequest`, re-renders the page with the message and a
    /// 400 status. Any other error is propagated.
    async fn list_add_handler(
        &self,
        user: &CurrentUser,
        kind: ListKind,
        domain: &str,
    ) -> WebResult<Response> {
        // Text inputs are submitted verbatim. Surrounding whitespace is never
        // part of a domain name, so drop it before validation rather than
        // rejecting the whole submission.
        match self.list_add(kind, domain.trim()).await {
            Ok(()) => Ok(Redirect::to(kind.base_path()).into_response()),
            Err(WebError::BadRequest(msg)) => {
                let page = self.render_list(user, kind, Some(msg)).await?;
                Ok((axum::http::StatusCode::BAD_REQUEST, page).into_response())
            }
            Err(e) => Err(e),
        }
    }

    /// Handles a remove-form submission.
    ///
    /// Redirects back to the list page on success and propagates any error.
    async fn list_remove_handler(&self, kind: ListKind, domain: &str) -> WebResult<Response> {
        // Trim for symmetry with `list_add_handler`.
        self.list_remove(kind, domain.trim()).await?;
        Ok(Redirect::to(kind.base_path()).into_response())
    }

    /// `GET` handler for the blacklist page.
    pub async fn blacklist_page(
        user: CurrentUser,
        State(state): State<AppState>,
    ) -> WebResult<Response> {
        state.list_page(&user, ListKind::Blacklist).await
    }

    /// Form handler that adds a domain to the blacklist.
    pub async fn blacklist_add(
        user: CurrentUser,
        State(state): State<AppState>,
        axum::Form(form): axum::Form<DomainForm>,
    ) -> WebResult<Response> {
        state
            .list_add_handler(&user, ListKind::Blacklist, &form.domain)
            .await
    }

    /// Form handler that removes a domain from the blacklist.
    pub async fn blacklist_remove(
        _user: CurrentUser,
        State(state): State<AppState>,
        axum::Form(form): axum::Form<DomainForm>,
    ) -> WebResult<Response> {
        state
            .list_remove_handler(ListKind::Blacklist, &form.domain)
            .await
    }

    /// `GET` handler for the allowlist page.
    pub async fn allowlist_page(
        user: CurrentUser,
        State(state): State<AppState>,
    ) -> WebResult<Response> {
        state.list_page(&user, ListKind::Allowlist).await
    }

    /// Form handler that adds a domain to the allowlist.
    pub async fn allowlist_add(
        user: CurrentUser,
        State(state): State<AppState>,
        axum::Form(form): axum::Form<DomainForm>,
    ) -> WebResult<Response> {
        state
            .list_add_handler(&user, ListKind::Allowlist, &form.domain)
            .await
    }

    /// Form handler that removes a domain from the allowlist.
    pub async fn allowlist_remove(
        _user: CurrentUser,
        State(state): State<AppState>,
        axum::Form(form): axum::Form<DomainForm>,
    ) -> WebResult<Response> {
        state
            .list_remove_handler(ListKind::Allowlist, &form.domain)
            .await
    }
}
/// URL-encoded form body posted by the list add/remove forms.
#[derive(Deserialize, Debug)]
pub struct DomainForm {
    /// The submitted domain.
    ///
    /// It is not validated here; the storage layer rejects invalid names.
    pub domain: String,
}
/// View model for `list_page.html`, shared by the blacklist and allowlist pages.
#[derive(Template, WebTemplate)]
#[template(path = "list_page.html")]
struct ListPageTemplate {
    /// Shared page chrome produced by `AppState::chrome`.
    chrome: crate::web::Chrome,
    /// Page heading.
    title: &'static str,
    /// Explanatory blurb under the heading.
    description: &'static str,
    /// Route prefix used by the page's forms and links.
    base_path: &'static str,
    /// Current list contents, already converted to display form.
    entries: Vec<String>,
    /// Validation message from a rejected add, if any.
    error: Option<String>,
}
/// Query string (`?domain=…`) accepted by the query-log block/unblock actions.
#[derive(Deserialize, Debug)]
pub struct ActionQuery {
    /// Domain the action applies to, as supplied by the client.
    pub domain: String,
}
/// Builds a Datastar response that patches the `#sgt-toast` element.
///
/// The new element shows `message` as an "ok" status notice. The message is
/// HTML-escaped before being embedded, so callers may pass user-supplied text.
fn toast(message: String) -> Response {
    let escaped = crate::web::render::html_escape(&message);
    let fragment = format!(
        r#"<div id="sgt-toast" class="sgt-toast sgt-notice--ok" role="status">{escaped}</div>"#
    );
    let event = PatchElements::new(fragment).write_as_axum_sse_event();
    datastar_response(vec![event])
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::codec::name::Name;
    use tempfile::TempDir;

    /// Fresh `AppState` backed by a throwaway database.
    ///
    /// Keep the returned `TempDir` alive for the duration of the test.
    async fn state() -> (TempDir, AppState) {
        let (dir, db) = crate::test_support::temp_db().await;
        let st = AppState::for_test(db).await;
        (dir, st)
    }

    /// Parses a test literal into a [`Name`]; the literal must be valid.
    fn name(s: &str) -> Name {
        s.parse().expect("test domain literal must parse")
    }

    #[tokio::test]
    async fn add_to_blacklist_writes_through_and_swaps() {
        let (_d, st) = state().await;
        assert!(!st.resolver.blacklist().contains(&name("ads.example.com")));
        st.add_to_blacklist("ads.example.com").await.expect("add");
        assert!(st.resolver.blacklist().contains(&name("ads.example.com")));
        let names = st.db.blacklist().load_all().await.unwrap();
        assert!(names.contains(&name("ads.example.com")));
    }

    #[tokio::test]
    async fn add_to_allowlist_writes_through_and_swaps() {
        let (_d, st) = state().await;
        st.add_to_allowlist("safe.example.com").await.expect("add");
        assert!(st.resolver.allowlist().contains(&name("safe.example.com")));
    }

    #[tokio::test]
    async fn remove_from_blacklist_writes_through_and_swaps() {
        let (_d, st) = state().await;
        st.add_to_blacklist("ads.example.com").await.expect("add");
        st.remove_from_blacklist("ads.example.com")
            .await
            .expect("remove");
        assert!(!st.resolver.blacklist().contains(&name("ads.example.com")));
        let names = st.db.blacklist().load_all().await.unwrap();
        assert!(!names.contains(&name("ads.example.com")));
    }

    #[tokio::test]
    async fn remove_from_allowlist_writes_through_and_swaps() {
        let (_d, st) = state().await;
        st.add_to_allowlist("safe.example.com").await.expect("add");
        st.remove_from_allowlist("safe.example.com")
            .await
            .expect("remove");
        assert!(!st.resolver.allowlist().contains(&name("safe.example.com")));
        let names = st.db.allowlist().load_all().await.unwrap();
        assert!(!names.contains(&name("safe.example.com")));
    }

    #[tokio::test]
    async fn invalid_domain_is_bad_request() {
        let (_d, st) = state().await;
        let err = st.add_to_blacklist("not..valid").await.unwrap_err();
        assert!(matches!(err, WebError::BadRequest(_)));
    }

    #[tokio::test]
    async fn unblock_removes_an_admin_blacklisted_domain() {
        let (_d, st) = state().await;
        st.add_to_blacklist("ads.example.com").await.expect("add");
        assert!(st.resolver.blacklist().contains(&name("ads.example.com")));
        st.unblock("ads.example.com").await.expect("unblock");
        assert!(!st.resolver.blacklist().contains(&name("ads.example.com")));
        assert!(!st.resolver.allowlist().contains(&name("ads.example.com")));
    }

    #[tokio::test]
    async fn unblock_allowlists_a_blocklist_blocked_domain() {
        let (_d, st) = state().await;
        st.unblock("tracker.example.com").await.expect("unblock");
        assert!(
            st.resolver
                .allowlist()
                .contains(&name("tracker.example.com"))
        );
    }
}