use std::net::{Ipv4Addr, Ipv6Addr};
use askama::Template;
use askama_web::WebTemplate;
use axum::{
extract::State,
http::{HeaderMap, StatusCode, Uri, header},
response::{IntoResponse, Redirect, Response},
};
use serde::Deserialize;
use crate::{
resolver::state::RuntimeSettings,
storage::{
query_log::QueryLogRepository,
settings::{BlockingMode, SelectionStrategy, Settings, SettingsRepository},
},
web::{
AppState, Chrome,
auth::CurrentUser,
render::{WebError, WebResult},
},
};
impl AppState {
async fn render_settings(
&self,
user: &CurrentUser,
error: Option<String>,
notice: Option<String>,
) -> WebResult<SettingsPageTemplate> {
let s = self.db.settings().get().await?;
Ok(SettingsPageTemplate {
chrome: self.chrome("settings", user).await,
cache_min_ttl: s.cache_min_ttl,
cache_max_ttl: s.cache_max_ttl,
cache_negative_ttl_cap: s.cache_negative_ttl_cap,
cache_capacity: s.cache_capacity,
blocking_mode: s.blocking_mode.as_str(),
custom_block_ipv4: s
.custom_block_ipv4
.map(|i| i.to_string())
.unwrap_or_default(),
custom_block_ipv6: s
.custom_block_ipv6
.map(|i| i.to_string())
.unwrap_or_default(),
blocklist_refresh_interval: s.blocklist_refresh_interval,
query_log_enabled: s.query_log_enabled,
query_log_retention_days: s.query_log_retention_days,
upstream_selection_strategy: s.upstream_selection_strategy.as_str(),
upstream_parallel_fanout: s.upstream_parallel_fanout,
error,
notice,
})
}
pub async fn settings_page(
user: CurrentUser,
State(state): State<AppState>,
) -> WebResult<Response> {
Ok(state
.render_settings(&user, None, None)
.await?
.into_response())
}
pub async fn settings_save(
user: CurrentUser,
State(state): State<AppState>,
axum::Form(form): axum::Form<SettingsForm>,
) -> WebResult<Response> {
match state.apply_settings(form).await {
Ok(()) => Ok(state
.render_settings(&user, None, Some("Settings saved.".to_owned()))
.await?
.into_response()),
Err(WebError::BadRequest(msg)) => {
let page = state.render_settings(&user, Some(msg), None).await?;
Ok((StatusCode::BAD_REQUEST, page).into_response())
}
Err(e) => Err(e),
}
}
pub async fn theme_toggle(
_user: CurrentUser,
State(state): State<AppState>,
headers: HeaderMap,
) -> WebResult<Response> {
let mut settings = state.db.settings().get().await?;
settings.ui_theme = if settings.ui_theme == "dark" {
"light".to_owned()
} else {
"dark".to_owned()
};
state.db.settings().update(&settings).await?;
let back = headers
.get(header::REFERER)
.and_then(|v| v.to_str().ok())
.and_then(|s| s.parse::<Uri>().ok())
.and_then(|u| u.path_and_query().map(|pq| pq.as_str().to_owned()))
.unwrap_or_else(|| "/".to_owned());
Ok(Redirect::to(&back).into_response())
}
pub async fn settings_clear_log(
user: CurrentUser,
State(state): State<AppState>,
) -> WebResult<Response> {
state.db.query_log().clear_all().await?;
Ok(state
.render_settings(&user, None, Some("Query log cleared.".to_owned()))
.await?
.into_response())
}
async fn apply_settings(&self, form: SettingsForm) -> WebResult<()> {
let ui_theme = self.db.settings().get().await?.ui_theme;
let settings = form.into_settings(ui_theme)?;
self.db.settings().update(&settings).await?;
self.resolver
.store_settings(RuntimeSettings::from(&settings));
self.resolver
.cache()
.set_ttl_bounds(settings.cache_min_ttl, settings.cache_max_ttl);
self.rebuild_upstream_pool().await?;
Ok(())
}
}
/// Form payload submitted by the settings page.
///
/// `settings_save` extracts it via `axum::Form`; the values are raw until
/// `SettingsForm::into_settings` validates and converts them.
#[derive(Debug, Deserialize)]
pub struct SettingsForm {
/// Lower cache TTL bound; rejected when it exceeds `cache_max_ttl`.
cache_min_ttl: u32,
/// Upper cache TTL bound.
cache_max_ttl: u32,
/// Negative-TTL cap; copied into the settings without further validation.
cache_negative_ttl_cap: u32,
/// Cache capacity; must be at least 1.
cache_capacity: u64,
/// Textual blocking mode, parsed into a `BlockingMode` during validation.
blocking_mode: String,
/// Custom IPv4 block address as typed by the user; empty when the field is
/// absent from the form. Strictly validated only in custom blocking mode.
#[serde(default)]
custom_block_ipv4: String,
/// Custom IPv6 block address as typed by the user; empty when the field is
/// absent from the form. Strictly validated only in custom blocking mode.
#[serde(default)]
custom_block_ipv6: String,
/// Blocklist refresh interval; copied into the settings unchanged
/// (units are not visible in this file).
blocklist_refresh_interval: u32,
/// Checkbox value: only its presence matters. `Some(_)` enables query
/// logging, `None` (field absent) disables it.
#[serde(default)]
query_log_enabled: Option<String>,
/// Query-log retention in days; must be at least 1.
query_log_retention_days: u32,
/// Textual upstream selection strategy, parsed into a `SelectionStrategy`
/// during validation.
upstream_selection_strategy: String,
/// Parallel fan-out; must be at least 1 regardless of the chosen strategy.
upstream_parallel_fanout: u32,
}
impl SettingsForm {
    /// Validates the submitted values and turns them into a [`Settings`].
    ///
    /// `ui_theme` is not part of the form, so the caller supplies the value
    /// that should be carried into the result.
    ///
    /// Checks run in a fixed order and the first failure is reported:
    /// TTL bounds, cache capacity, blocking mode, custom block addresses,
    /// query-log retention, upstream selection strategy, parallel fan-out.
    ///
    /// Custom block addresses are validated strictly only when the blocking
    /// mode is `Custom`. In every other mode an unparsable address is
    /// silently dropped, while a parsable one is kept.
    ///
    /// # Errors
    ///
    /// Returns a bad-request `WebError` describing the first rule violated.
    fn into_settings(self, ui_theme: String) -> WebResult<Settings> {
        let Self {
            cache_min_ttl,
            cache_max_ttl,
            cache_negative_ttl_cap,
            cache_capacity,
            blocking_mode,
            custom_block_ipv4,
            custom_block_ipv6,
            blocklist_refresh_interval,
            query_log_enabled,
            query_log_retention_days,
            upstream_selection_strategy,
            upstream_parallel_fanout,
        } = self;

        if cache_min_ttl > cache_max_ttl {
            return Err(WebError::bad_request(
                "Minimum cache TTL must not exceed the maximum.",
            ));
        }
        if cache_capacity == 0 {
            return Err(WebError::bad_request("Cache capacity must be at least 1."));
        }

        let blocking_mode = blocking_mode
            .parse::<BlockingMode>()
            .map_err(|_| WebError::bad_request("Invalid blocking mode."))?;

        // Only custom mode actually uses the addresses, so only then is a
        // malformed value treated as an error.
        let strict_ips = blocking_mode == BlockingMode::Custom;
        let custom_block_ipv4 = if strict_ips {
            parse_opt_ip::<Ipv4Addr>(&custom_block_ipv4, "IPv4")?
        } else {
            custom_block_ipv4.trim().parse::<Ipv4Addr>().ok()
        };
        let custom_block_ipv6 = if strict_ips {
            parse_opt_ip::<Ipv6Addr>(&custom_block_ipv6, "IPv6")?
        } else {
            custom_block_ipv6.trim().parse::<Ipv6Addr>().ok()
        };

        if query_log_retention_days == 0 {
            return Err(WebError::bad_request(
                "Query-log retention must be at least 1 day.",
            ));
        }

        let upstream_selection_strategy = upstream_selection_strategy
            .parse::<SelectionStrategy>()
            .map_err(|_| WebError::bad_request("Invalid upstream selection strategy."))?;

        if upstream_parallel_fanout == 0 {
            return Err(WebError::bad_request(
                "Parallel fan-out must be at least 1.",
            ));
        }

        Ok(Settings {
            cache_min_ttl,
            cache_max_ttl,
            cache_negative_ttl_cap,
            cache_capacity,
            blocking_mode,
            custom_block_ipv4,
            custom_block_ipv6,
            blocklist_refresh_interval,
            ui_theme,
            // A ticked checkbox submits some value; an unticked one is absent.
            query_log_enabled: query_log_enabled.is_some(),
            query_log_retention_days,
            upstream_selection_strategy,
            upstream_parallel_fanout,
        })
    }
}
/// Parses an optional address typed into a form field.
///
/// Surrounding whitespace is ignored. A blank field yields `Ok(None)`; a
/// value that parses as `T` yields `Ok(Some(_))`.
///
/// # Errors
///
/// Returns a bad-request `WebError` quoting the trimmed input and `label`
/// when the value is non-empty but does not parse as `T`.
fn parse_opt_ip<T: std::str::FromStr>(s: &str, label: &str) -> Result<Option<T>, WebError> {
    match s.trim() {
        "" => Ok(None),
        s => match s.parse::<T>() {
            Ok(addr) => Ok(Some(addr)),
            Err(_) => Err(WebError::bad_request(format!(
                "'{s}' is not a valid {label} address."
            ))),
        },
    }
}
/// View model for the `settings.html` template.
///
/// Filled in by `AppState::render_settings` from the persisted settings plus
/// the optional banners passed by the calling handler.
#[derive(Template, WebTemplate)]
#[template(path = "settings.html")]
struct SettingsPageTemplate {
/// Shared page chrome produced by `AppState::chrome("settings", user)`.
chrome: Chrome,
/// Persisted lower cache TTL bound.
cache_min_ttl: u32,
/// Persisted upper cache TTL bound.
cache_max_ttl: u32,
/// Persisted negative-TTL cap.
cache_negative_ttl_cap: u32,
/// Persisted cache capacity.
cache_capacity: u64,
/// Persisted blocking mode in its `as_str()` form.
blocking_mode: &'static str,
/// Custom IPv4 block address as text; empty when none is stored.
custom_block_ipv4: String,
/// Custom IPv6 block address as text; empty when none is stored.
custom_block_ipv6: String,
/// Persisted blocklist refresh interval.
blocklist_refresh_interval: u32,
/// Whether query logging is currently enabled.
query_log_enabled: bool,
/// Persisted query-log retention in days.
query_log_retention_days: u32,
/// Persisted upstream selection strategy in its `as_str()` form.
upstream_selection_strategy: &'static str,
/// Persisted parallel fan-out.
upstream_parallel_fanout: u32,
/// Error message to show, e.g. a validation failure from a rejected save.
error: Option<String>,
/// Confirmation message to show, e.g. after a successful save.
notice: Option<String>,
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::codec::synth::BlockMode;
    use tempfile::TempDir;

    /// Creates an `AppState` on top of a temporary database. The `TempDir`
    /// is returned alongside so the caller can hold on to it for the test.
    async fn state() -> (TempDir, AppState) {
        let (dir, db) = crate::test_support::temp_db().await;
        let app = AppState::for_test(db).await;
        (dir, app)
    }

    /// A form that passes every validation rule; tests override single fields.
    fn base_form() -> SettingsForm {
        SettingsForm {
            cache_min_ttl: 10,
            cache_max_ttl: 3600,
            cache_negative_ttl_cap: 300,
            cache_capacity: 50_000,
            blocking_mode: "nxdomain".to_owned(),
            custom_block_ipv4: String::new(),
            custom_block_ipv6: String::new(),
            blocklist_refresh_interval: 7200,
            query_log_enabled: Some("on".to_owned()),
            query_log_retention_days: 30,
            upstream_selection_strategy: "random".to_owned(),
            upstream_parallel_fanout: 2,
        }
    }

    /// Applies `form` to a fresh state and asserts it is rejected as a bad request.
    async fn assert_bad_request(form: SettingsForm) {
        let (_d, st) = state().await;
        let outcome = st.apply_settings(form).await;
        assert!(matches!(outcome, Err(WebError::BadRequest(_))));
    }

    #[tokio::test]
    async fn apply_settings_persists_and_updates_snapshot() {
        let (_d, st) = state().await;
        st.apply_settings(base_form()).await.expect("apply");

        let stored = st.db.settings().get().await.unwrap();
        assert_eq!(stored.cache_max_ttl, 3600);
        assert_eq!(stored.blocking_mode, BlockingMode::NxDomain);
        assert_eq!(stored.ui_theme, "auto");

        assert_eq!(st.resolver.settings().block_mode, BlockMode::NxDomain);
        assert_eq!(st.resolver.settings().cache_max_ttl, 3600);
    }

    #[tokio::test]
    async fn parallel_strategy_rebuilds_pool_in_parallel_mode() {
        let (_d, st) = state().await;
        assert_eq!(st.upstream_pool.load().parallel_fanout(), None);

        let form = SettingsForm {
            upstream_selection_strategy: "parallel".to_owned(),
            upstream_parallel_fanout: 3,
            ..base_form()
        };
        st.apply_settings(form).await.expect("apply");

        let stored = st.db.settings().get().await.unwrap();
        assert_eq!(
            stored.upstream_selection_strategy,
            SelectionStrategy::Parallel
        );
        assert_eq!(
            st.resolver.settings().upstream_selection_strategy,
            SelectionStrategy::Parallel
        );
        assert_eq!(st.upstream_pool.load().parallel_fanout(), Some(3));
    }

    #[tokio::test]
    async fn invalid_fanout_is_rejected() {
        assert_bad_request(SettingsForm {
            upstream_parallel_fanout: 0,
            ..base_form()
        })
        .await;
    }

    #[tokio::test]
    async fn min_greater_than_max_is_rejected() {
        assert_bad_request(SettingsForm {
            cache_min_ttl: 5000,
            cache_max_ttl: 100,
            ..base_form()
        })
        .await;
    }

    #[tokio::test]
    async fn custom_mode_with_ips_round_trips() {
        let (_d, st) = state().await;
        let form = SettingsForm {
            blocking_mode: "custom".to_owned(),
            custom_block_ipv4: "203.0.113.1".to_owned(),
            ..base_form()
        };
        st.apply_settings(form).await.expect("apply custom");

        let stored = st.db.settings().get().await.unwrap();
        assert_eq!(stored.blocking_mode, BlockingMode::Custom);
        assert_eq!(
            stored.custom_block_ipv4,
            Some("203.0.113.1".parse().unwrap())
        );
    }

    #[tokio::test]
    async fn invalid_custom_ip_in_custom_mode_is_rejected() {
        assert_bad_request(SettingsForm {
            blocking_mode: "custom".to_owned(),
            custom_block_ipv4: "not-an-ip".to_owned(),
            ..base_form()
        })
        .await;
    }

    #[tokio::test]
    async fn invalid_custom_ip_ignored_when_mode_not_custom() {
        let (_d, st) = state().await;
        let form = SettingsForm {
            custom_block_ipv4: "not-an-ip".to_owned(),
            ..base_form()
        };
        st.apply_settings(form).await.expect("apply");

        let stored = st.db.settings().get().await.unwrap();
        assert_eq!(stored.custom_block_ipv4, None);
    }

    #[tokio::test]
    async fn valid_custom_ip_preserved_across_non_custom_save() {
        let (_d, st) = state().await;
        let form = SettingsForm {
            custom_block_ipv4: "203.0.113.9".to_owned(),
            ..base_form()
        };
        st.apply_settings(form).await.expect("apply");

        let stored = st.db.settings().get().await.unwrap();
        assert_eq!(
            stored.custom_block_ipv4,
            Some("203.0.113.9".parse().unwrap())
        );
    }

    #[tokio::test]
    async fn query_log_fields_round_trip() {
        let (_d, st) = state().await;
        let form = SettingsForm {
            query_log_enabled: None,
            query_log_retention_days: 7,
            ..base_form()
        };
        st.apply_settings(form).await.expect("apply");

        let stored = st.db.settings().get().await.unwrap();
        assert!(
            !stored.query_log_enabled,
            "unticked checkbox disables logging"
        );
        assert_eq!(stored.query_log_retention_days, 7);

        assert!(!st.resolver.settings().query_log_enabled);
        assert_eq!(st.resolver.settings().query_log_retention_days, 7);
    }

    #[tokio::test]
    async fn zero_retention_is_rejected() {
        assert_bad_request(SettingsForm {
            query_log_retention_days: 0,
            ..base_form()
        })
        .await;
    }

    #[tokio::test]
    async fn clear_log_empties_the_table() {
        use crate::{
            resolver::pipeline::Outcome,
            storage::query_log::{QueryLogRecord, QueryLogRepository},
        };

        let (_d, st) = state().await;
        let repo = st.db.query_log();

        let seed = QueryLogRecord {
            id: 0,
            ts: 1,
            client: "10.0.0.1".to_owned(),
            qname: "x.test.".to_owned(),
            qtype: "A".to_owned(),
            outcome: Outcome::Forwarded,
            rcode: Some(0),
            upstream: None,
            latency_ms: 1,
            blocklist_id: None,
        };
        repo.insert_batch(&[seed]).await.expect("seed a row");

        repo.clear_all().await.expect("clear");
        assert!(repo.page(None, 10).await.expect("page").is_empty());
    }
}