use std::path::PathBuf;
use std::sync::Arc;
use std::time::Instant;
use std::net::IpAddr;
use dashmap::DashMap;
use std::convert::Infallible;
use std::time::Duration;
use axum::{
extract::{Path, Query, State, rejection::QueryRejection},
http::{HeaderValue, Request, StatusCode},
middleware::{self, Next},
response::{IntoResponse, Response},
response::sse::{Event, KeepAlive, Sse},
Json as JsonExtract,
Router,
routing::{delete, get, post},
};
use arc_swap::ArcSwap;
use futures_util::stream;
use serde::Deserialize;
use tokio::sync::Mutex;
use tracing::{error, info, warn};
use crate::dns::{BlacklistAction, ZoneAction, local::{LocalZoneSet, parse_local_data}};
use crate::feeds::{self, FeedFormat, add_feed, builtin_presets, remove_feed, update_all_feeds, update_one_feed};
use crate::logbuffer::{LogAction, LogQuery, SharedLogBuffer};
use crate::store::{self, DnsEntry, DnsType, BlacklistEntry};
use crate::config::parser::{TlsConfig, UnboundConfig};
use crate::stats::Stats;
use crate::audit::{AuditEvent, AuditLogger};
use crate::sync::{SyncJournal, SyncOp};
use crate::upstreams::SharedUpstreams;
/// Upper bound (seconds, i.e. 24 h) for the TTL of DNS entries created through
/// the API; larger accepted TTLs are clamped to this value by `add_dns_handler`.
const MAX_API_TTL: u32 = 86_400;

/// JSON body extractor whose rejections are reported as a structured JSON body
/// (`{"error": "INVALID_REQUEST", "details": ...}`) instead of axum's default
/// plain-text response.
struct ApiJson<T>(T);

#[axum::async_trait]
impl<T, S> axum::extract::FromRequest<S> for ApiJson<T>
where
    T: serde::de::DeserializeOwned,
    S: Send + Sync,
{
    type Rejection = (StatusCode, axum::Json<serde_json::Value>);

    async fn from_request(req: axum::extract::Request, state: &S) -> Result<Self, Self::Rejection> {
        use axum::extract::rejection::JsonRejection;
        match axum::Json::<T>::from_request(req, state).await {
            Ok(axum::Json(val)) => Ok(ApiJson(val)),
            Err(rejection) => {
                let (status, msg) = match rejection {
                    JsonRejection::JsonDataError(e) => (StatusCode::UNPROCESSABLE_ENTITY, e.to_string()),
                    JsonRejection::JsonSyntaxError(e) => (StatusCode::BAD_REQUEST, e.to_string()),
                    JsonRejection::MissingJsonContentType(e) => (StatusCode::UNSUPPORTED_MEDIA_TYPE, e.to_string()),
                    // Any other rejection (e.g. failing to buffer the body, which
                    // includes exceeding the `DefaultBodyLimit`) carries its own
                    // status — 413 for an oversized body — which is more accurate
                    // than a blanket 400.
                    e => (e.status(), e.to_string()),
                };
                Err((status, axum::Json(serde_json::json!({
                    "error": "INVALID_REQUEST",
                    "details": msg
                }))))
            }
        }
    }
}
/// Process-wide API key: created once by `init_api_key`, swapped atomically by
/// `rotate_api_key`, and read on every request through `get_api_key`.
static API_KEY: std::sync::OnceLock<ArcSwap<String>> = std::sync::OnceLock::new();
/// Consecutive failed API authentications; incremented on each failure and reset
/// to 0 after a successful check in `security_middleware`.
static AUTH_FAILURES: std::sync::atomic::AtomicU64 = std::sync::atomic::AtomicU64::new(0);
/// Largest accepted request body in bytes (Content-Length pre-check in
/// `security_middleware` and the `DefaultBodyLimit` layer in `router`).
const MAX_BODY_BYTES: usize = 65_536;
/// Refill rate of the API rate limiter, in tokens per second.
const API_RATE_LIMIT_RPS: u64 = 30;
/// Capacity of an API rate-limiter bucket (maximum tokens held at once).
const API_RATE_BURST: u64 = 60;
/// Maximum number of API-managed DNS entries accepted by `add_dns_handler`.
const MAX_DNS_ENTRIES: usize = 10_000;
/// Maximum number of blacklist entries accepted by `add_blacklist_handler`.
const MAX_BLACKLIST_ENTRIES: usize = 100_000;
/// Maximum number of feed subscriptions accepted by `add_feed_handler`.
const MAX_FEEDS: usize = 100;
/// Initialises the process-wide API key and returns the key that is in effect.
///
/// On the first call the key is taken from, in order: the HSM
/// (`crate::hsm::api_key`), the `RUNBOUND_API_KEY` environment variable,
/// `config_key`, and finally a freshly generated random key (two UUIDv4s in
/// simple form). Empty or whitespace-only candidates are skipped, so a blank
/// setting can never become a credential that matches the bare header `Bearer `.
///
/// The key store is initialised only once. Later calls do not replace it; they
/// return the key actually stored (which `rotate_api_key` may have changed), so
/// the returned value always matches what requests are checked against.
pub fn init_api_key(config_key: Option<String>) -> String {
    let usable = |k: &String| !k.trim().is_empty();
    let swap = API_KEY.get_or_init(|| {
        let key = crate::hsm::api_key()
            .map(|k| k.to_string())
            .filter(usable)
            .or_else(|| std::env::var("RUNBOUND_API_KEY").ok().filter(usable))
            .or_else(|| config_key.filter(usable))
            .unwrap_or_else(|| {
                format!("{}{}",
                    uuid::Uuid::new_v4().simple(),
                    uuid::Uuid::new_v4().simple())
            });
        ArcSwap::from(Arc::new(key))
    });
    swap.load_full().as_ref().clone()
}
/// Returns a snapshot of the current API key.
///
/// Yields an empty string when `init_api_key` has not been called yet.
pub fn get_api_key() -> Arc<String> {
    match API_KEY.get() {
        Some(current) => current.load_full(),
        None => Arc::default(),
    }
}
pub fn rotate_api_key(new_key: String) {
if let Some(swap) = API_KEY.get() {
swap.store(Arc::new(new_key));
}
}
struct ApiBucket { tokens: u64, last: Instant }
#[derive(Clone)]
pub struct ApiRateLimiter(Arc<DashMap<IpAddr, ApiBucket, ahash::RandomState>>);
impl ApiRateLimiter {
fn new() -> Self {
Self(Arc::new(DashMap::with_hasher(ahash::RandomState::default())))
}
pub fn new_public() -> Self { Self::new() }
#[inline]
fn check(&self, ip: IpAddr) -> bool {
let now = Instant::now();
let mut b = self.0.entry(ip).or_insert(ApiBucket { tokens: API_RATE_BURST, last: now });
let elapsed_ms = now.duration_since(b.last).as_millis() as u64;
if elapsed_ms >= 1000 {
b.tokens = API_RATE_BURST; b.last = now;
} else {
let new = (API_RATE_LIMIT_RPS * elapsed_ms) / 1000;
if new > 0 { b.tokens = (b.tokens + new).min(API_RATE_BURST); b.last = now; }
}
if b.tokens > 0 { b.tokens -= 1; true } else { false }
}
}
/// Shared state handed to every API handler and middleware.
#[derive(Clone)]
pub struct AppState {
/// Live local-zone set, replaced wholesale (clone, modify, store) on every change.
pub zones: Arc<ArcSwap<LocalZoneSet>>,
/// Serialises the read-modify-write of on-disk stores and `zones` in the add/delete handlers.
pub zones_mutex: Arc<Mutex<()>>,
/// TLS settings reported by `GET /tls`.
pub tls_cfg: Arc<TlsConfig>,
/// Limiter consulted by `security_middleware` before authentication.
pub rate_limiter: ApiRateLimiter,
/// Query statistics source for `/health`, `/stats` and `/stats/stream`.
pub stats: Arc<Stats>,
/// Configuration reported by `/config` and used to rebuild zones after feed updates.
pub cfg: Arc<UnboundConfig>,
/// Path re-read by `POST /reload`.
pub cfg_path: String,
/// In-memory query log served by `GET /logs` and cleared by `DELETE /logs`.
pub log_buffer: SharedLogBuffer,
/// Upstream resolver health, reported by `GET /upstreams`.
pub upstreams: SharedUpstreams,
/// When present, successful mutations are pushed here as `SyncOp`s.
pub sync_journal: Option<Arc<SyncJournal>>,
/// When true, `slave_guard_middleware` rejects write requests.
pub slave_mode: bool,
/// Directory that contains `audit.log` (read by `GET /audit/tail`).
pub base_dir: Arc<PathBuf>,
/// Sink for audit events emitted by handlers and middleware.
pub audit: AuditLogger,
}
/// JSON body for `POST /dns`.
///
/// Only `name` and `type` are required, and `ttl` defaults to 3600. Which
/// optional fields are needed depends on the record type. When the combination
/// is incomplete, `DnsEntry::to_rr_string` yields `None` and `add_dns_handler`
/// answers "Missing required fields for this record type".
#[derive(Debug, Deserialize)]
pub struct AddDnsRequest {
/// Owner name; validated with `validate_dns_name` and made fully qualified.
pub name: String,
/// Record type, spelled `"type"` in the JSON body.
#[serde(rename = "type")]
pub entry_type: DnsType,
/// TTL; `add_dns_handler` rejects values outside 0..=2147483647 and clamps the rest to `MAX_API_TTL`.
#[serde(default = "default_ttl_i64")]
pub ttl: i64,
/// Main record data; validated as a DNS name for CNAME/NS/PTR/MX/SRV.
pub value: Option<String>,
pub priority: Option<u16>,
pub weight: Option<u16>,
pub port: Option<u16>,
pub flags: Option<u8>,
pub tag: Option<String>,
// NOTE(review): the fields below appear to be NAPTR, SSHFP and TLSA parameters
// (judging by their names); confirm against `DnsEntry::to_rr_string`.
pub order: Option<u16>,
pub preference_naptr: Option<u16>,
pub flags_naptr: Option<String>,
pub services: Option<String>,
pub regexp: Option<String>,
/// NAPTR replacement; validated as a DNS name unless it is exactly ".".
pub replacement: Option<String>,
pub algorithm: Option<u8>,
pub fp_type: Option<u8>,
pub fingerprint: Option<String>,
pub cert_usage: Option<u8>,
pub selector: Option<u8>,
pub matching_type: Option<u8>,
pub cert_data: Option<String>,
/// Free-text note stored with the entry.
pub description: Option<String>,
}
/// Default TTL used when `AddDnsRequest::ttl` is omitted.
fn default_ttl_i64() -> i64 { 3600 }
/// JSON body for `POST /feeds`; `format` and `action` fall back to their `Default` values.
#[derive(Debug, Deserialize)]
pub struct AddFeedRequest {
pub name: String,
pub url: String,
#[serde(default)]
pub format: FeedFormat,
#[serde(default)]
pub action: BlacklistAction,
pub description: Option<String>,
}
/// JSON body for `POST /blacklist`; `action` falls back to `BlacklistAction::default()`.
#[derive(Debug, Deserialize)]
pub struct AddBlacklistRequest {
/// Domain to block; validated with `validate_dns_name`.
pub domain: String,
#[serde(default)]
pub action: BlacklistAction,
/// Optional note; must not contain control characters.
pub description: Option<String>,
}
/// Outermost middleware for every API route.
///
/// In order it:
/// 1. rejects bodies announced larger than `MAX_BODY_BYTES` (413);
/// 2. applies the API rate limiter (429 with `Retry-After: 1`);
/// 3. requires `Authorization: Bearer <api key>` (401); the comparison is
///    constant-time;
/// 4. runs the inner service and adds hardening / no-cache headers to the
///    response.
///
/// Failed authentications are audited. Every 10th consecutive failure logs a
/// warning, and from the 50th on each failure is delayed by 500 ms. A
/// successful check resets the counter.
async fn security_middleware(
    State(state): State<AppState>,
    req: Request<axum::body::Body>,
    next: Next,
) -> Response {
    // Early rejection based on Content-Length only; body extractors are also
    // capped by the `DefaultBodyLimit` layer installed in `router`.
    if let Some(cl) = req.headers().get(axum::http::header::CONTENT_LENGTH) {
        let len: usize = cl.to_str().unwrap_or("0").parse().unwrap_or(0);
        if len > MAX_BODY_BYTES {
            return (StatusCode::PAYLOAD_TOO_LARGE, axum::Json(serde_json::json!({
                "error": "REQUEST_TOO_LARGE",
                "details": format!("Body exceeds {} bytes", MAX_BODY_BYTES)
            }))).into_response();
        }
    }
    // NOTE(review): the limiter key is a fixed loopback address, so all clients
    // share one bucket; keying per client would need the peer address (e.g.
    // axum `ConnectInfo`) plus eviction in `ApiRateLimiter`.
    let client_ip: IpAddr = IpAddr::from([127, 0, 0, 1]);
    if !state.rate_limiter.check(client_ip) {
        warn!(%client_ip, "API rate limited");
        return (StatusCode::TOO_MANY_REQUESTS,
            [(axum::http::header::RETRY_AFTER, "1")],
            "Rate limit exceeded").into_response();
    }
    let path = req.uri().path();
    {
        let auth = req.headers()
            .get(axum::http::header::AUTHORIZATION)
            .and_then(|v| v.to_str().ok())
            .unwrap_or("");
        let key = get_api_key();
        let expected = format!("Bearer {}", key.as_str());
        // An empty key (not initialised, or configured blank) would make the
        // bare header "Bearer " a valid credential; fail closed instead.
        if key.is_empty() || !constant_time_eq(auth.as_bytes(), expected.as_bytes()) {
            let failures = AUTH_FAILURES.fetch_add(1, std::sync::atomic::Ordering::Relaxed) + 1;
            state.audit.send(AuditEvent::AuthFailure { path: path.to_string() });
            if failures.is_multiple_of(10) {
                warn!(failures, %path, "Repeated API authentication failures — check RUNBOUND_API_KEY");
            }
            if failures >= 50 {
                // Slow down sustained guessing once failures pile up.
                tokio::time::sleep(std::time::Duration::from_millis(500)).await;
            }
            return (StatusCode::UNAUTHORIZED,
                [(axum::http::header::WWW_AUTHENTICATE, "Bearer realm=\"runbound\"")],
                "Unauthorized").into_response();
        }
        AUTH_FAILURES.store(0, std::sync::atomic::Ordering::Relaxed);
    }
    let mut response = next.run(req).await;
    let headers = response.headers_mut();
    headers.insert("x-content-type-options", HeaderValue::from_static("nosniff"));
    headers.insert("x-frame-options", HeaderValue::from_static("DENY"));
    headers.insert("x-xss-protection", HeaderValue::from_static("1; mode=block"));
    headers.insert("referrer-policy", HeaderValue::from_static("no-referrer"));
    headers.insert("content-security-policy", HeaderValue::from_static("default-src 'none'"));
    headers.insert("cache-control", HeaderValue::from_static("no-store"));
    headers.insert("x-accel-buffering", HeaderValue::from_static("no"));
    response
}
/// Compares two byte strings without exiting early on the first mismatch.
///
/// Every byte of `b` (the expected value) is visited. Missing bytes of `a` are
/// treated as 0, and a length difference is folded into the accumulator rather
/// than returned early. The final zero test uses `subtle::ConstantTimeEq`.
#[inline(always)]
fn constant_time_eq(a: &[u8], b: &[u8]) -> bool {
    use subtle::ConstantTimeEq;
    let mut acc = u8::from(a.len() != b.len());
    for (idx, &want) in b.iter().enumerate() {
        let got = a.get(idx).copied().unwrap_or(0);
        acc |= got ^ want;
    }
    bool::from(acc.ct_eq(&0u8))
}
/// Rejects write requests with 503 `READ_ONLY` when this node runs as a slave
/// replica.
///
/// GET and HEAD are safe, read-only methods, so they are always passed through.
/// axum serves HEAD from `get` routes, so HEAD `/health` and similar checks
/// must not be treated as writes.
async fn slave_guard_middleware(
    State(state): State<AppState>,
    req: Request<axum::body::Body>,
    next: Next,
) -> Response {
    let method = req.method();
    let read_only = method == axum::http::Method::GET || method == axum::http::Method::HEAD;
    if state.slave_mode && !read_only {
        return (StatusCode::SERVICE_UNAVAILABLE, JsonExtract(serde_json::json!({
            "error": "READ_ONLY",
            "details": "This node is a slave replica — write operations are disabled",
        }))).into_response();
    }
    next.run(req).await
}
/// Builds the management API router.
///
/// The layers are added innermost first: the slave guard, then the
/// security middleware (auth, rate limiting, response headers), then the body
/// size limit.
pub fn router(state: AppState) -> Router {
    let api = Router::new()
        // Service information and telemetry
        .route("/help", get(help_handler))
        .route("/health", get(health_handler))
        .route("/stats", get(stats_handler))
        .route("/stats/stream", get(stats_stream_handler))
        .route("/config", get(config_handler))
        .route("/reload", post(reload_handler))
        // Local DNS data and blacklist
        .route("/dns", get(list_dns_handler).post(add_dns_handler))
        .route("/dns/:id", delete(delete_dns_handler))
        .route("/blacklist", get(list_blacklist_handler).post(add_blacklist_handler))
        .route("/blacklist/:id", delete(delete_blacklist_handler))
        // Blocklist feeds
        .route("/feeds", get(get_feeds_handler).post(add_feed_handler))
        .route("/feeds/presets", get(feed_presets_handler))
        .route("/feeds/update", post(update_feeds_handler))
        .route("/feeds/:id", delete(delete_feed_handler))
        .route("/feeds/:id/update", post(update_one_feed_handler))
        // Runtime status, logs and administration
        .route("/tls", get(tls_status_handler))
        .route("/upstreams", get(upstreams_handler))
        .route("/logs", get(logs_handler).delete(clear_logs_handler))
        .route("/audit/tail", get(audit_tail_handler))
        .route("/metrics", get(metrics_handler))
        .route("/rotate-key", post(rotate_key_handler));

    api.layer(middleware::from_fn_with_state(state.clone(), slave_guard_middleware))
        .layer(middleware::from_fn_with_state(state.clone(), security_middleware))
        .layer(axum::extract::DefaultBodyLimit::max(MAX_BODY_BYTES))
        .with_state(state)
}
/// `GET /help` — static self-description: service name, version, protocols,
/// implemented RFCs and the endpoint catalogue.
async fn help_handler() -> impl IntoResponse {
    // (method, path, description) for every documented endpoint, in display order.
    const ENDPOINTS: &[(&str, &str, &str)] = &[
        ("GET", "/help", "API documentation"),
        ("GET", "/health", "Liveness check"),
        ("GET", "/stats", "Query statistics snapshot"),
        ("GET", "/stats/stream", "Live stats as Server-Sent Events (1-second interval)"),
        ("GET", "/config", "Running configuration"),
        ("POST", "/reload", "Hot-reload zones and blacklist from disk"),
        ("GET", "/dns", "List all local DNS entries"),
        ("POST", "/dns", "Add a local DNS entry (A/AAAA/CNAME/TXT/MX/SRV/CAA/PTR/NAPTR/SSHFP/TLSA/NS)"),
        ("DELETE", "/dns/:id", "Remove a DNS entry by UUID"),
        ("GET", "/blacklist", "List blacklist entries"),
        ("POST", "/blacklist", "Add a domain to the blacklist (refuse/nxdomain)"),
        ("DELETE", "/blacklist/:id", "Remove a blacklist entry"),
        ("GET", "/feeds", "List feed subscriptions"),
        ("POST", "/feeds", "Subscribe to a remote blocklist"),
        ("DELETE", "/feeds/:id", "Remove a feed subscription"),
        ("POST", "/feeds/update", "Refresh all feeds"),
        ("POST", "/feeds/:id/update", "Refresh one feed"),
        ("GET", "/feeds/presets", "List pre-configured blocklists"),
        ("GET", "/tls", "DoT/DoH/DoQ TLS status"),
        ("GET", "/upstreams", "Upstream DNS resolver health"),
        ("GET", "/logs", "Recent query log (newest first) — ?limit=100&page=0&action=blocked&client=1.2.3.4&since=<unix>"),
        ("DELETE", "/logs", "Clear the in-memory query log ring buffer (GDPR right-to-erasure)"),
        ("GET", "/audit/tail", "Last N audit log entries — ?n=100"),
        ("GET", "/metrics", "Prometheus/OpenMetrics exposition (text/plain; version=0.0.4)"),
        ("POST", "/rotate-key", "Atomically rotate API key — reads new key from RUNBOUND_API_KEY env var"),
    ];
    let endpoints: Vec<serde_json::Value> = ENDPOINTS
        .iter()
        .map(|&(method, path, description)| serde_json::json!({
            "method": method,
            "path": path,
            "description": description,
        }))
        .collect();
    let protocols = ["DNS/UDP:53", "DNS/TCP:53", "DoT:853", "DoH:443", "DoQ:853/UDP"];
    let rfcs = [
        "RFC1034", "RFC1035", "RFC2782", "RFC4033", "RFC4034", "RFC4035",
        "RFC6698", "RFC6891", "RFC7858", "RFC8484", "RFC9250",
    ];
    JsonExtract(serde_json::json!({
        "service": "Runbound DNS",
        "version": env!("CARGO_PKG_VERSION"),
        "protocols": protocols,
        "rfcs": rfcs,
        "endpoints": endpoints,
    }))
}
async fn health_handler(State(s): State<AppState>) -> impl IntoResponse {
let snap = s.stats.snapshot();
JsonExtract(serde_json::json!({
"status": "ok",
"uptime_secs": snap.uptime_secs,
"queries": snap.total,
"hsm": crate::hsm::is_active(),
}))
}
/// `GET /stats` — one statistics snapshot rendered by `stats_json`.
async fn stats_handler(State(state): State<AppState>) -> impl IntoResponse {
    let snapshot = state.stats.snapshot();
    JsonExtract(stats_json(&snapshot))
}
/// Renders a statistics snapshot as the JSON object shared by `/stats` and
/// `/stats/stream`.
///
/// `blocked_percent` is `blocked / total` as a percentage rounded to one
/// decimal place, or 0.0 when no queries have been seen.
fn stats_json(snap: &crate::stats::StatsSnapshot) -> serde_json::Value {
    let pct_blocked = (snap.total > 0)
        .then(|| (snap.blocked as f64 / snap.total as f64 * 1000.0).round() / 10.0)
        .unwrap_or(0.0);
    let dnssec = serde_json::json!({
        "secure": snap.dnssec_secure,
        "bogus": snap.dnssec_bogus,
        "insecure": snap.dnssec_insecure,
    });
    serde_json::json!({
        "total": snap.total,
        "blocked": snap.blocked,
        "forwarded": snap.forwarded,
        "nxdomain": snap.nxdomain,
        "refused": snap.refused,
        "servfail": snap.servfail,
        "local_hits": snap.local_hits,
        "blocked_percent": pct_blocked,
        "uptime_secs": snap.uptime_secs,
        "qps_1m": snap.qps_1m,
        "qps_5m": snap.qps_5m,
        "qps_peak": snap.qps_peak,
        "latency_p50_ms": snap.latency_p50_ms,
        "latency_p95_ms": snap.latency_p95_ms,
        "latency_p99_ms": snap.latency_p99_ms,
        "cache_hit_rate": snap.cache_hit_rate,
        "cache_entries": snap.cache_entries,
        "dnssec": dnssec,
    })
}
/// `GET /stats/stream` — Server-Sent Events stream emitting one `stats_json`
/// snapshot per second, with default keep-alive comments.
async fn stats_stream_handler(
    State(s): State<AppState>,
) -> Sse<impl stream::Stream<Item = Result<Event, Infallible>>> {
    // The stream state is the shared `Stats` handle itself; each step sleeps one
    // second, then emits the current snapshot and hands the handle back.
    let ticks = stream::unfold(s.stats, |stats| async move {
        tokio::time::sleep(Duration::from_secs(1)).await;
        let payload = stats_json(&stats.snapshot()).to_string();
        Some((Ok::<Event, Infallible>(Event::default().data(payload)), stats))
    });
    Sse::new(ticks).keep_alive(KeepAlive::default())
}
/// `GET /config` — running configuration summary.
///
/// Includes counts of file-defined and API-managed data; a count is reported
/// as 0 when its store cannot be loaded. The HSM PIN is never echoed: only
/// `"***"` is shown when one is configured.
async fn config_handler(State(s): State<AppState>) -> impl IntoResponse {
    let cfg = s.cfg.as_ref();
    let api_dns = match store::load() { Ok(st) => st.entries.len(), Err(_) => 0 };
    let api_bl = match store::load_blacklist() { Ok(bl) => bl.entries.len(), Err(_) => 0 };
    let api_feeds = match crate::feeds::load_feeds() { Ok(f) => f.feeds.len(), Err(_) => 0 };
    let forward_zones: Vec<serde_json::Value> = cfg
        .forward_zones
        .iter()
        .map(|fz| serde_json::json!({
            "name": fz.name,
            "addrs": fz.addrs,
            "tls": fz.tls,
        }))
        .collect();
    let hsm = serde_json::json!({
        "active": crate::hsm::is_active(),
        "pkcs11_lib": cfg.hsm_pkcs11_lib,
        "slot": cfg.hsm_slot,
        "pin": cfg.hsm_pin.as_ref().map(|_| "***"),
        "api_key_label": cfg.hsm_api_key_label,
        "store_key_label": cfg.hsm_store_key_label,
    });
    JsonExtract(serde_json::json!({
        "port": cfg.port,
        "interfaces": cfg.interfaces,
        "forward_zones": forward_zones,
        "file_local_zones": cfg.local_zones.len(),
        "file_local_data": cfg.local_data.len(),
        "api_dns_entries": api_dns,
        "api_blacklist": api_bl,
        "api_feeds": api_feeds,
        "access_control": cfg.access_control,
        "private_addresses": cfg.private_addresses,
        "rate_limit": cfg.rate_limit,
        "cache_max_ttl": cfg.cache_max_ttl,
        "dnssec_validation": cfg.dnssec_validation,
        "log_retention": cfg.log_retention,
        "log_client_ip": cfg.log_client_ip,
        "api_port": cfg.api_port,
        "logfile": cfg.logfile,
        "hsm": hsm,
    }))
}
async fn reload_handler(State(s): State<AppState>) -> impl IntoResponse {
match crate::config::load(&s.cfg_path) {
Ok(new_cfg) => {
let new_zones = crate::build_zone_set(&new_cfg);
s.zones.store(std::sync::Arc::new(new_zones));
info!(cfg_path = %s.cfg_path, "API hot-reload complete");
s.audit.send(AuditEvent::ConfigReload);
(StatusCode::OK, JsonExtract(serde_json::json!({
"status": "ok",
"cfg_path": s.cfg_path,
"local_zones": new_cfg.local_zones.len(),
"local_data": new_cfg.local_data.len(),
})))
}
Err(e) => {
warn!(err = %e, "API reload failed — keeping current zones");
(StatusCode::INTERNAL_SERVER_ERROR, JsonExtract(serde_json::json!({
"error": "RELOAD_FAILED",
"details": e.to_string(),
})))
}
}
}
/// `GET /dns` — every API-managed DNS entry plus their count; 500 if the store
/// cannot be read.
async fn list_dns_handler(State(_state): State<AppState>) -> impl IntoResponse {
    let loaded = store::load();
    match loaded {
        Err(err) => (StatusCode::INTERNAL_SERVER_ERROR, JsonExtract(serde_json::json!({
            "error": err.to_string()
        }))),
        Ok(data) => {
            let total = data.entries.len();
            (StatusCode::OK, JsonExtract(serde_json::json!({
                "entries": data.entries,
                "total": total
            })))
        }
    }
}
async fn add_dns_handler(
State(s): State<AppState>,
ApiJson(req): ApiJson<AddDnsRequest>,
) -> impl IntoResponse {
if let Err(e) = validate_dns_name(&req.name) {
return (StatusCode::BAD_REQUEST, JsonExtract(serde_json::json!({
"error": "INVALID_NAME", "details": e
})));
}
for (field, val) in [
("value", req.value.as_deref().unwrap_or("")),
("tag", req.tag.as_deref().unwrap_or("")),
("description", req.description.as_deref().unwrap_or("")),
("fingerprint", req.fingerprint.as_deref().unwrap_or("")),
("cert_data", req.cert_data.as_deref().unwrap_or("")),
("services", req.services.as_deref().unwrap_or("")),
("regexp", req.regexp.as_deref().unwrap_or("")),
("replacement", req.replacement.as_deref().unwrap_or("")),
("flags_naptr", req.flags_naptr.as_deref().unwrap_or("")),
] {
if let Err(e) = validate_no_control_chars(val, field) {
return (StatusCode::BAD_REQUEST, JsonExtract(serde_json::json!({
"error": "INVALID_FIELD", "details": e
})));
}
}
match req.entry_type {
DnsType::CNAME | DnsType::NS | DnsType::PTR | DnsType::MX | DnsType::SRV => {
if let Some(ref v) = req.value {
if let Err(e) = validate_dns_name(v) {
return (StatusCode::BAD_REQUEST, JsonExtract(serde_json::json!({
"error": "INVALID_VALUE", "details": e
})));
}
}
}
DnsType::NAPTR => {
if let Some(ref r) = req.replacement {
if r != "." {
if let Err(e) = validate_dns_name(r) {
return (StatusCode::BAD_REQUEST, JsonExtract(serde_json::json!({
"error": "INVALID_REPLACEMENT", "details": e
})));
}
}
}
}
_ => {}
}
const RFC2181_MAX_TTL: i64 = 2_147_483_647;
if req.ttl < 0 || req.ttl > RFC2181_MAX_TTL {
return (StatusCode::UNPROCESSABLE_ENTITY, JsonExtract(serde_json::json!({
"error": "INVALID_TTL",
"details": "TTL must be between 0 and 2147483647"
})));
}
let ttl = req.ttl as u32;
let entry = DnsEntry {
id: DnsEntry::new_id(),
name: ensure_dot(&req.name),
entry_type: req.entry_type,
ttl: ttl.min(MAX_API_TTL),
value: req.value,
priority: req.priority,
weight: req.weight,
port: req.port,
flags: req.flags,
tag: req.tag,
order: req.order,
preference_naptr: req.preference_naptr,
flags_naptr: req.flags_naptr,
services: req.services,
regexp: req.regexp,
replacement: req.replacement,
algorithm: req.algorithm,
fp_type: req.fp_type,
fingerprint: req.fingerprint,
cert_usage: req.cert_usage,
selector: req.selector,
matching_type: req.matching_type,
cert_data: req.cert_data,
description: req.description,
};
let rr = match entry.to_rr_string() {
Some(r) => r,
None => return (StatusCode::BAD_REQUEST, JsonExtract(serde_json::json!({
"error": "INVALID_ENTRY",
"details": "Missing required fields for this record type"
}))),
};
let record = match parse_local_data(&rr) {
Some(r) => r,
None => return (StatusCode::BAD_REQUEST, JsonExtract(serde_json::json!({
"error": "PARSE_FAILED",
"details": format!("Could not parse RR: {rr}")
}))),
};
{
let _guard = s.zones_mutex.lock().await;
let mut st = store::load().unwrap_or_default();
if st.entries.len() >= MAX_DNS_ENTRIES {
return (StatusCode::UNPROCESSABLE_ENTITY, JsonExtract(serde_json::json!({
"error": "LIMIT_EXCEEDED",
"details": format!("Maximum {} DNS entries reached", MAX_DNS_ENTRIES)
})));
}
st.entries.push(entry.clone());
if let Err(e) = store::save(&st) {
return (StatusCode::INTERNAL_SERVER_ERROR, JsonExtract(serde_json::json!({
"error": e.to_string()
})));
}
let current = s.zones.load_full();
let mut new_zones = (*current).clone();
let name = record.name.clone();
new_zones.zones.entry(name.clone()).or_insert(ZoneAction::Static);
new_zones.records.entry(name).or_default().push(record);
s.zones.store(Arc::new(new_zones));
}
info!(id=%entry.id, name=%entry.name, r#type=?entry.entry_type, "DNS entry added");
s.audit.send(AuditEvent::DnsAdd {
name: entry.name.clone(),
rtype: format!("{:?}", entry.entry_type),
value: entry.value.clone().unwrap_or_default(),
});
if let Some(ref j) = s.sync_journal {
j.push(SyncOp::AddDns { entry: entry.clone() });
}
(StatusCode::CREATED, JsonExtract(serde_json::json!({
"status": "ok",
"entry": entry,
"rr": rr
})))
}
async fn delete_dns_handler(
State(s): State<AppState>,
Path(id): Path<String>,
) -> impl IntoResponse {
let _guard = s.zones_mutex.lock().await;
let mut st = match store::load() {
Ok(s) => s,
Err(e) => return (StatusCode::INTERNAL_SERVER_ERROR, JsonExtract(serde_json::json!({"error": e.to_string()}))),
};
let pos = st.entries.iter().position(|e| e.id == id);
let Some(pos) = pos else {
return (StatusCode::NOT_FOUND, JsonExtract(serde_json::json!({"error":"NOT_FOUND","id":id})));
};
let entry = st.entries.remove(pos);
if let Err(e) = store::save(&st) {
return (StatusCode::INTERNAL_SERVER_ERROR, JsonExtract(serde_json::json!({"error": e.to_string()})));
}
if let Some(rr) = entry.to_rr_string() {
if let Some(record) = parse_local_data(&rr) {
let current = s.zones.load_full();
let mut new_zones = (*current).clone();
let name = record.name.clone();
if let Some(recs) = new_zones.records.get_mut(&name) {
let mut removed = false;
recs.retain(|r| {
if !removed && r == &record {
removed = true;
false
} else {
true
}
});
if recs.is_empty() {
new_zones.records.remove(&name);
new_zones.zones.remove(&name);
}
}
s.zones.store(Arc::new(new_zones));
}
}
info!(id=%id, "DNS entry deleted");
s.audit.send(AuditEvent::DnsDelete { id: id.clone() });
if let Some(ref j) = s.sync_journal {
j.push(SyncOp::DeleteDns { id: id.clone() });
}
(StatusCode::OK, JsonExtract(serde_json::json!({"status":"ok","deleted_id":id})))
}
/// `GET /blacklist` — every blacklist entry plus their count; 500 if the store
/// cannot be read.
async fn list_blacklist_handler(State(_state): State<AppState>) -> impl IntoResponse {
    let loaded = store::load_blacklist();
    match loaded {
        Err(err) => (StatusCode::INTERNAL_SERVER_ERROR, JsonExtract(serde_json::json!({
            "error": err.to_string()
        }))),
        Ok(list) => {
            let total = list.entries.len();
            (StatusCode::OK, JsonExtract(serde_json::json!({
                "blacklist": list.entries,
                "total": total
            })))
        }
    }
}
/// `POST /blacklist` — adds a domain to the blacklist.
///
/// The domain is validated, and the description is checked for control
/// characters. Then, under `zones_mutex`, the handler enforces
/// `MAX_BLACKLIST_ENTRIES`, persists the store, and overrides the domain's zone
/// in the live set with the requested action.
///
/// Responses:
/// - 201 with the stored entry;
/// - 400 on validation errors;
/// - 422 at the limit;
/// - 500 on store errors.
async fn add_blacklist_handler(
    State(s): State<AppState>,
    ApiJson(req): ApiJson<AddBlacklistRequest>,
) -> impl IntoResponse {
    if let Err(e) = validate_dns_name(&req.domain) {
        return (StatusCode::BAD_REQUEST, JsonExtract(serde_json::json!({
            "error": "INVALID_NAME", "details": e
        })));
    }
    if let Some(ref desc) = req.description {
        if let Err(e) = validate_no_control_chars(desc, "description") {
            return (StatusCode::BAD_REQUEST, JsonExtract(serde_json::json!({
                "error": "INVALID_FIELD", "details": e
            })));
        }
    }
    let entry = {
        let _guard = s.zones_mutex.lock().await;
        // Fail rather than fall back to an empty list: saving a default list
        // after a failed read would erase every existing entry on disk.
        // NOTE(review): assumes `load_blacklist` returns Ok(default) when the
        // file does not exist yet (as `list_blacklist_handler` also relies on).
        let mut bl = match store::load_blacklist() {
            Ok(bl) => bl,
            Err(e) => return (StatusCode::INTERNAL_SERVER_ERROR, JsonExtract(serde_json::json!({
                "error": e.to_string()
            }))),
        };
        if bl.entries.len() >= MAX_BLACKLIST_ENTRIES {
            return (StatusCode::UNPROCESSABLE_ENTITY, JsonExtract(serde_json::json!({
                "error": "LIMIT_EXCEEDED",
                "details": format!("Maximum {} blacklist entries reached", MAX_BLACKLIST_ENTRIES)
            })));
        }
        let entry = BlacklistEntry {
            id: uuid::Uuid::new_v4().to_string(),
            domain: req.domain.clone(),
            action: req.action.clone(),
            description: req.description.clone(),
        };
        bl.entries.push(entry.clone());
        if let Err(e) = store::save_blacklist(&bl) {
            return (StatusCode::INTERNAL_SERVER_ERROR, JsonExtract(serde_json::json!({
                "error": e.to_string()
            })));
        }
        let current = s.zones.load_full();
        let mut new_zones = (*current).clone();
        new_zones.override_zone(&req.domain, ZoneAction::from(&req.action));
        s.zones.store(Arc::new(new_zones));
        entry
    };
    info!(domain=%req.domain, action=?req.action, "Blacklist entry added");
    s.audit.send(AuditEvent::BlacklistAdd { domain: entry.domain.clone() });
    if let Some(ref j) = s.sync_journal {
        j.push(SyncOp::AddBlacklist { entry: entry.clone() });
    }
    (StatusCode::CREATED, JsonExtract(serde_json::json!({
        "status": "ok",
        "entry": entry
    })))
}
async fn delete_blacklist_handler(
State(s): State<AppState>,
Path(id): Path<String>,
) -> impl IntoResponse {
let _guard = s.zones_mutex.lock().await;
let mut bl = match store::load_blacklist() {
Ok(b) => b,
Err(e) => return (StatusCode::INTERNAL_SERVER_ERROR, JsonExtract(serde_json::json!({"error": e.to_string()}))),
};
let pos = bl.entries.iter().position(|e| e.id == id);
let Some(pos) = pos else {
return (StatusCode::NOT_FOUND, JsonExtract(serde_json::json!({"error":"NOT_FOUND","id":id})));
};
let removed = bl.entries.remove(pos);
if let Err(e) = store::save_blacklist(&bl) {
return (StatusCode::INTERNAL_SERVER_ERROR, JsonExtract(serde_json::json!({"error": e.to_string()})));
}
let current = s.zones.load_full();
let mut new_zones = (*current).clone();
new_zones.remove_zone(&removed.domain);
s.zones.store(Arc::new(new_zones));
info!(id=%id, domain=%removed.domain, "Blacklist entry deleted");
s.audit.send(AuditEvent::BlacklistDelete { id: id.clone() });
if let Some(ref j) = s.sync_journal {
j.push(SyncOp::DeleteBlacklist { id: id.clone() });
}
(StatusCode::OK, JsonExtract(serde_json::json!({"status":"ok","deleted_id":id,"domain":removed.domain})))
}
/// `GET /feeds` — all feed subscriptions and their count. An unreadable feed
/// configuration is reported as an empty list.
async fn get_feeds_handler(State(_state): State<AppState>) -> impl IntoResponse {
    let feeds_cfg = match feeds::load_feeds() {
        Ok(cfg) => cfg,
        Err(_) => Default::default(),
    };
    let total = feeds_cfg.feeds.len();
    (StatusCode::OK, JsonExtract(serde_json::json!({"feeds": feeds_cfg.feeds, "total": total})))
}
async fn add_feed_handler(
State(s): State<AppState>,
JsonExtract(p): JsonExtract<AddFeedRequest>,
) -> impl IntoResponse {
let current = feeds::load_feeds().unwrap_or_default();
if current.feeds.len() >= MAX_FEEDS {
return (StatusCode::UNPROCESSABLE_ENTITY, JsonExtract(serde_json::json!({
"error": "LIMIT_EXCEEDED",
"details": format!("Maximum {} feed subscriptions reached", MAX_FEEDS)
})));
}
match add_feed(p.name, p.url, p.format, p.action, p.description).await {
Ok(feed) => {
info!("Feed added: {} ({})", feed.name, feed.url);
s.audit.send(AuditEvent::FeedAdd {
id: feed.id.clone(),
name: feed.name.clone(),
url: feed.url.clone(),
});
if let Some(ref j) = s.sync_journal {
j.push(SyncOp::AddFeed { feed: feed.clone() });
}
(StatusCode::CREATED, JsonExtract(serde_json::json!({
"status": "ok", "feed": feed,
"message": "Run POST /feeds/:id/update to fetch domains."
})))
}
Err(e) => {
let code = StatusCode::from_u16(e.status_code()).unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
(code, JsonExtract(serde_json::json!({
"error": "FEED_ERROR", "details": e.to_string()
})))
}
}
}
async fn delete_feed_handler(
State(s): State<AppState>,
Path(id): Path<String>,
) -> impl IntoResponse {
match remove_feed(&id) {
Ok(()) => {
s.audit.send(AuditEvent::FeedDelete { id: id.clone() });
if let Some(ref j) = s.sync_journal {
j.push(SyncOp::DeleteFeed { id: id.clone() });
}
(StatusCode::OK, JsonExtract(serde_json::json!({"status":"ok","deleted_id":id})))
}
Err(crate::error::AppError::BadRequest(msg)) => (StatusCode::BAD_REQUEST, JsonExtract(serde_json::json!({"error":"BAD_REQUEST","details":msg}))),
Err(e) => (StatusCode::NOT_FOUND, JsonExtract(serde_json::json!({"error":"FEED_NOT_FOUND","details":e.to_string()}))),
}
}
async fn update_feeds_handler(State(s): State<AppState>) -> impl IntoResponse {
match update_all_feeds().await {
Ok(results) => {
let updated = results.iter().filter(|r| r.status == "updated").count();
let errors = results.iter().filter(|r| r.status == "error").count();
let new_zones = crate::build_zone_set(&s.cfg);
s.zones.store(std::sync::Arc::new(new_zones));
info!(updated, errors, "Feed update complete — zones rebuilt");
(StatusCode::OK, JsonExtract(serde_json::json!({
"status": "ok", "results": results,
"summary": {"updated": updated, "errors": errors}
})))
}
Err(e) => (StatusCode::INTERNAL_SERVER_ERROR, JsonExtract(serde_json::json!({"error":e.to_string()}))),
}
}
async fn update_one_feed_handler(
State(s): State<AppState>,
Path(id): Path<String>,
) -> impl IntoResponse {
let feed_url = feeds::load_feeds()
.ok()
.and_then(|cfg| cfg.feeds.into_iter().find(|f| f.id == id))
.map(|f| f.url);
match update_one_feed(&id).await {
Ok(result) => {
let new_zones = crate::build_zone_set(&s.cfg);
s.zones.store(std::sync::Arc::new(new_zones));
if result.error.is_none() {
if let (Some(j), Some(url)) = (s.sync_journal.as_ref(), feed_url) {
j.push(SyncOp::UpdateFeed { id: id.clone(), url });
}
}
let code = if result.error.is_some() { StatusCode::INTERNAL_SERVER_ERROR } else { StatusCode::OK };
(code, JsonExtract(serde_json::json!({"result": result})))
}
Err(crate::error::AppError::BadRequest(msg)) => (StatusCode::BAD_REQUEST, JsonExtract(serde_json::json!({"error":"BAD_REQUEST","details":msg}))),
Err(e) => (StatusCode::NOT_FOUND, JsonExtract(serde_json::json!({"error":e.to_string()}))),
}
}
/// `GET /feeds/presets` — the built-in blocklist presets and their count.
async fn feed_presets_handler() -> impl IntoResponse {
    let presets = builtin_presets();
    let total = presets.len();
    JsonExtract(serde_json::json!({"presets": presets, "total": total}))
}
async fn upstreams_handler(State(s): State<AppState>) -> impl IntoResponse {
let statuses = match s.upstreams.read() {
Ok(g) => g.clone(),
Err(e) => {
error!(err = %e, "upstreams RwLock poisoned");
return (StatusCode::INTERNAL_SERVER_ERROR, JsonExtract(serde_json::json!({
"error": "INTERNAL", "details": "upstream state unavailable"
}))).into_response();
}
};
let total = statuses.len();
let healthy = statuses.iter().filter(|u| u.healthy).count();
(StatusCode::OK, JsonExtract(serde_json::json!({
"upstreams": statuses,
"total": total,
"healthy": healthy,
}))).into_response()
}
/// Largest `limit` accepted by `GET /logs`; larger values are rejected with 422.
const LOG_LIMIT_MAX: usize = 1_000;
/// `limit` used by `GET /logs` when the parameter is omitted.
const LOG_LIMIT_DEFAULT: usize = 100;
/// Query-string parameters of `GET /logs`.
#[derive(Deserialize)]
struct LogsParams {
/// Page size; defaults to `LOG_LIMIT_DEFAULT`.
#[serde(default = "default_log_limit")]
limit: usize,
/// Zero-based page index; defaults to 0.
#[serde(default)]
page: usize,
/// Optional action filter, parsed with `LogAction::from_str`.
action: Option<String>,
/// Optional client filter; must parse as an IP address.
client: Option<String>,
/// Optional lower time bound, passed through as `LogQuery::since_secs` (Unix seconds per the /help text).
since: Option<u64>,
}
/// Serde default for `LogsParams::limit`.
fn default_log_limit() -> usize { LOG_LIMIT_DEFAULT }
async fn logs_handler(
State(s): State<AppState>,
params_result: Result<Query<LogsParams>, QueryRejection>,
) -> Response {
let Query(params) = match params_result {
Ok(q) => q,
Err(e) => return (StatusCode::BAD_REQUEST, JsonExtract(serde_json::json!({
"error": "INVALID_PARAM",
"details": e.to_string()
}))).into_response(),
};
if params.limit > LOG_LIMIT_MAX {
return (StatusCode::UNPROCESSABLE_ENTITY, JsonExtract(serde_json::json!({
"error": "INVALID_PARAM",
"details": format!("limit must be ≤ {}", LOG_LIMIT_MAX),
}))).into_response();
}
let action = match params.action.as_deref() {
Some(s) => match LogAction::from_str(s) {
Some(a) => Some(a),
None => return (StatusCode::BAD_REQUEST, JsonExtract(serde_json::json!({
"error": "INVALID_PARAM",
"details": format!("action '{}' is not valid — expected one of: forwarded, cached, local, blocked, nxdomain, refused, servfail", s),
}))).into_response(),
},
None => None,
};
let client = match params.client.as_deref() {
Some(s) => match s.parse::<std::net::IpAddr>() {
Ok(ip) => Some(ip),
Err(_) => return (StatusCode::BAD_REQUEST, JsonExtract(serde_json::json!({
"error": "INVALID_PARAM",
"details": format!("client '{}' is not a valid IP address", s),
}))).into_response(),
},
None => None,
};
let q = LogQuery {
limit: params.limit,
page: params.page,
action,
client,
since_secs: params.since,
};
let (entries, total) = match s.log_buffer.lock() {
Ok(buf) => buf.query(&q),
Err(e) => {
error!(err = %e, "log_buffer Mutex poisoned");
return (StatusCode::INTERNAL_SERVER_ERROR, JsonExtract(serde_json::json!({
"error": "INTERNAL", "details": "log buffer unavailable"
}))).into_response();
}
};
JsonExtract(serde_json::json!({
"entries": entries,
"total": total,
"page": params.page,
"limit": params.limit,
})).into_response()
}
async fn clear_logs_handler(
State(s): State<AppState>,
) -> impl IntoResponse {
let deleted = match s.log_buffer.lock() {
Ok(mut buf) => buf.clear(),
Err(e) => {
error!(err = %e, "log_buffer Mutex poisoned");
return (StatusCode::INTERNAL_SERVER_ERROR, JsonExtract(serde_json::json!({
"error": "INTERNAL", "details": "log buffer unavailable"
}))).into_response();
}
};
s.audit.send(AuditEvent::LogsClear { count: deleted });
info!(entries_deleted = deleted, "log buffer cleared via DELETE /logs");
JsonExtract(serde_json::json!({
"message": "log buffer cleared",
"entries_deleted": deleted,
})).into_response()
}
/// Reports the TLS-based DNS transports (DoT, DoH, DoQ) as JSON.
///
/// Each transport is reported as enabled exactly when both a certificate path
/// and a key path are configured. Ports fall back to 853 (DoT), 443 (DoH) and
/// 853 (DoQ) when not set. The certificate path and hostname fall back to
/// `"not configured"` and `"runbound.local"` respectively.
async fn tls_status_handler(State(s): State<AppState>) -> impl IntoResponse {
    let tls = &*s.tls_cfg;

    // The same cert+key condition gates every transport, so evaluate it once.
    let enabled = tls.cert_path.is_some() && tls.key_path.is_some();
    let cert = tls.cert_path.as_deref().unwrap_or("not configured");
    let hostname = tls.hostname.as_deref().unwrap_or("runbound.local");

    let dot_port = tls.dot_port.unwrap_or(853);
    let doh_port = tls.doh_port.unwrap_or(443);
    let doq_port = tls.doq_port.unwrap_or(853);

    JsonExtract(serde_json::json!({
        "dot": { "enabled": enabled, "port": dot_port, "rfc": "RFC 7858" },
        "doh": { "enabled": enabled, "port": doh_port, "rfc": "RFC 8484" },
        "doq": { "enabled": enabled, "port": doq_port, "rfc": "RFC 9250" },
        "cert": cert,
        "hostname": hostname
    }))
}
#[derive(Deserialize)]
struct AuditTailQuery { n: Option<usize> }
async fn audit_tail_handler(
State(s): State<AppState>,
Query(q): Query<AuditTailQuery>,
) -> impl IntoResponse {
let n = q.n.unwrap_or(100).min(1000);
let log_path = s.base_dir.join("audit.log");
match crate::audit::tail_audit_log(&log_path, n) {
Ok(lines) => (StatusCode::OK, JsonExtract(serde_json::json!({
"lines": lines,
"count": lines.len(),
}))),
Err(e) => (StatusCode::NOT_FOUND, JsonExtract(serde_json::json!({
"error": "AUDIT_LOG_UNAVAILABLE",
"details": e,
}))),
}
}
/// Renders the current stats snapshot in the Prometheus text exposition format
/// (version 0.0.4).
///
/// Each metric family is emitted as a `# HELP` line, a `# TYPE` line and one or
/// more samples. Every line, including the last, is newline-terminated.
async fn metrics_handler(State(s): State<AppState>) -> impl IntoResponse {
    const PROMETHEUS_TEXT: &str = "text/plain; version=0.0.4; charset=utf-8";

    let snap = s.stats.snapshot();

    // Each `\n\` ends one exposition line. The trailing backslash continues the
    // literal and drops the next line's leading indentation, so no stray
    // whitespace reaches the output.
    let body = format!(
        "# HELP runbound_queries_total Total DNS queries received\n\
         # TYPE runbound_queries_total counter\n\
         runbound_queries_total {total}\n\
         # HELP runbound_blocked_total Queries answered with REFUSED (blacklist/feeds)\n\
         # TYPE runbound_blocked_total counter\n\
         runbound_blocked_total {blocked}\n\
         # HELP runbound_nxdomain_total Queries answered with NXDOMAIN\n\
         # TYPE runbound_nxdomain_total counter\n\
         runbound_nxdomain_total {nxdomain}\n\
         # HELP runbound_refused_total Queries answered with REFUSED (ACL/rate limit)\n\
         # TYPE runbound_refused_total counter\n\
         runbound_refused_total {refused}\n\
         # HELP runbound_servfail_total Queries answered with SERVFAIL\n\
         # TYPE runbound_servfail_total counter\n\
         runbound_servfail_total {servfail}\n\
         # HELP runbound_forwarded_total Queries forwarded to upstream resolvers\n\
         # TYPE runbound_forwarded_total counter\n\
         runbound_forwarded_total {forwarded}\n\
         # HELP runbound_local_hits_total Queries answered from local zone data\n\
         # TYPE runbound_local_hits_total counter\n\
         runbound_local_hits_total {local_hits}\n\
         # HELP runbound_uptime_seconds Process uptime in seconds\n\
         # TYPE runbound_uptime_seconds gauge\n\
         runbound_uptime_seconds {uptime}\n\
         # HELP runbound_qps Queries per second\n\
         # TYPE runbound_qps gauge\n\
         runbound_qps{{window=\"1m\"}} {qps_1m}\n\
         runbound_qps{{window=\"5m\"}} {qps_5m}\n\
         runbound_qps{{window=\"peak\"}} {qps_peak}\n\
         # HELP runbound_latency_ms DNS query latency percentiles in milliseconds\n\
         # TYPE runbound_latency_ms gauge\n\
         runbound_latency_ms{{quantile=\"0.5\"}} {p50}\n\
         runbound_latency_ms{{quantile=\"0.95\"}} {p95}\n\
         runbound_latency_ms{{quantile=\"0.99\"}} {p99}\n\
         # HELP runbound_cache_hit_rate Cache hit rate percentage (0–100)\n\
         # TYPE runbound_cache_hit_rate gauge\n\
         runbound_cache_hit_rate {cache_hit_rate}\n\
         # HELP runbound_cache_entries Approximate cached DNS entries\n\
         # TYPE runbound_cache_entries gauge\n\
         runbound_cache_entries {cache_entries}\n\
         # HELP runbound_dnssec_total DNSSEC validation results\n\
         # TYPE runbound_dnssec_total counter\n\
         runbound_dnssec_total{{status=\"secure\"}} {dnssec_secure}\n\
         runbound_dnssec_total{{status=\"bogus\"}} {dnssec_bogus}\n\
         runbound_dnssec_total{{status=\"insecure\"}} {dnssec_insecure}\n",
        // Response-code counters.
        total = snap.total,
        blocked = snap.blocked,
        nxdomain = snap.nxdomain,
        refused = snap.refused,
        servfail = snap.servfail,
        forwarded = snap.forwarded,
        local_hits = snap.local_hits,
        // Process-level gauges.
        uptime = snap.uptime_secs,
        qps_1m = snap.qps_1m,
        qps_5m = snap.qps_5m,
        qps_peak = snap.qps_peak,
        p50 = snap.latency_p50_ms,
        p95 = snap.latency_p95_ms,
        p99 = snap.latency_p99_ms,
        cache_hit_rate = snap.cache_hit_rate,
        cache_entries = snap.cache_entries,
        // DNSSEC outcome counters.
        dnssec_secure = snap.dnssec_secure,
        dnssec_bogus = snap.dnssec_bogus,
        dnssec_insecure = snap.dnssec_insecure,
    );

    let headers = [(axum::http::header::CONTENT_TYPE, PROMETHEUS_TEXT)];
    (StatusCode::OK, headers, body)
}
/// JSON body accepted by `rotate_key_handler`.
#[derive(Deserialize)]
struct RotateKeyRequest {
    /// Replacement API token: at least 32 bytes, no ASCII control characters.
    new_key: String,
}

/// Writes `key` to `path` without ever exposing it to other local users.
///
/// The key goes to a sibling temporary file (`<path>.tmp`, e.g. `api.key.tmp`).
/// On Unix that file is restricted to mode 0600 *before* any key bytes are
/// written. The file is then fsynced and atomically renamed over `path`. This
/// closes the window left by `fs::write` + `set_permissions`, during which the
/// key was readable under the process umask. It also ensures a crash mid-write
/// never leaves a truncated key file in place.
///
/// # Errors
/// Propagates any I/O error from open/chmod/write/fsync/rename. On failure the
/// temporary file is removed on a best-effort basis.
fn persist_api_key(path: &std::path::Path, key: &str) -> std::io::Result<()> {
    use std::io::Write;

    let tmp_path = path.with_extension("key.tmp");

    let mut opts = std::fs::OpenOptions::new();
    opts.write(true).create(true).truncate(true);
    #[cfg(unix)]
    {
        use std::os::unix::fs::OpenOptionsExt;
        // Only honoured when the file is newly created; see the explicit chmod below.
        opts.mode(0o600);
    }

    let written = opts.open(&tmp_path).and_then(|mut file| {
        #[cfg(unix)]
        {
            use std::os::unix::fs::PermissionsExt;
            // A leftover temp file keeps its previous mode, so tighten it
            // before the secret is written.
            file.set_permissions(std::fs::Permissions::from_mode(0o600))?;
        }
        file.write_all(key.as_bytes())?;
        file.sync_all()
    });

    // The file handle is closed at the end of the closure, before the rename.
    match written.and_then(|()| std::fs::rename(&tmp_path, path)) {
        Ok(()) => Ok(()),
        Err(e) => {
            // Best effort: do not leave a copy of the secret lying around.
            let _ = std::fs::remove_file(&tmp_path);
            Err(e)
        }
    }
}

/// `POST /rotate-key`: replaces the API bearer token.
///
/// Rejects keys shorter than 32 bytes (`WEAK_KEY`) or containing ASCII control
/// characters (`INVALID_KEY`) with 400. Otherwise it saves the new key to
/// `api.key` in the base directory and swaps the in-memory key. The response
/// includes `"persisted"` so callers can tell whether the key reached disk.
async fn rotate_key_handler(
    State(s): State<AppState>,
    ApiJson(req): ApiJson<RotateKeyRequest>,
) -> impl IntoResponse {
    // Length is measured in UTF-8 bytes (identical to characters for ASCII keys).
    if req.new_key.len() < 32 {
        return (StatusCode::BAD_REQUEST, JsonExtract(serde_json::json!({
            "error": "WEAK_KEY",
            "details": "new_key must be at least 32 characters",
        }))).into_response();
    }
    if req.new_key.bytes().any(|b| b < 0x20 || b == 0x7f) {
        return (StatusCode::BAD_REQUEST, JsonExtract(serde_json::json!({
            "error": "INVALID_KEY",
            "details": "new_key must not contain control characters",
        }))).into_response();
    }

    let key_path = s.base_dir.join("api.key");
    let persisted = match persist_api_key(&key_path, &req.new_key) {
        Ok(()) => true,
        Err(e) => {
            warn!(path = %key_path.display(), err = %e, "Failed to persist rotated API key to disk");
            false
        }
    };

    // Swap the in-memory key even if persisting failed. The caller may be
    // rotating away from a leaked token, so the old one must stop working now.
    rotate_api_key(req.new_key);
    s.audit.send(AuditEvent::ConfigReload);
    info!(persisted, "API key rotated via POST /rotate-key");

    let message = if persisted {
        "API key rotated — old token is immediately invalid"
    } else {
        "API key rotated — old token is immediately invalid, but the new key could not be saved to disk"
    };
    (StatusCode::OK, JsonExtract(serde_json::json!({
        "status": "ok",
        "message": message,
        "persisted": persisted,
    }))).into_response()
}
/// Returns `name` in fully-qualified form, appending a single trailing `.` if it
/// does not already end with one. An empty input yields `"."`.
fn ensure_dot(name: &str) -> String {
    let mut fqdn = String::with_capacity(name.len() + 1);
    fqdn.push_str(name);
    if !fqdn.ends_with('.') {
        fqdn.push('.');
    }
    fqdn
}
/// Rejects `s` if it contains any control character.
///
/// This covers the C0 range (U+0000–U+001F), DEL (U+007F) and the C1 range
/// (U+0080–U+009F). The C1 range includes NEL (U+0085), which some consumers
/// treat as a line break, so it matters for log and zone-file injection just
/// like `\r`/`\n`.
///
/// # Errors
/// Returns a human-readable message naming `field`.
fn validate_no_control_chars(s: &str, field: &'static str) -> Result<(), String> {
    if s.chars().any(char::is_control) {
        // `\\r` / `\\n` are escaped so the message shows them literally rather
        // than embedding a real carriage return and line feed in the error text.
        return Err(format!(
            "Field '{}' must not contain control characters (\\r, \\n, etc.)",
            field
        ));
    }
    Ok(())
}
/// Validates `name` as a DNS domain name, optionally fully-qualified.
///
/// Rules enforced:
/// - at most one trailing dot is allowed; the empty string and the root zone
///   (`"."`) are rejected;
/// - the name without its trailing dot is at most 253 bytes;
/// - every label is 1–63 bytes of ASCII letters, digits, `-` or `_`, and does
///   not start or end with `-`.
///
/// # Errors
/// Returns a static description of the first rule that is violated.
fn validate_dns_name(name: &str) -> Result<(), &'static str> {
    // Strip exactly one trailing dot. `trim_end_matches` would also reduce
    // "example.com.." to "example.com", silently accepting an empty final label.
    let n = name.strip_suffix('.').unwrap_or(name);
    if n.is_empty() {
        return Err("Domain name cannot be empty or the root zone");
    }
    if n.len() > 253 {
        return Err("Domain name exceeds 253 characters");
    }
    for label in n.split('.') {
        if label.is_empty() {
            return Err("Domain label cannot be empty (no consecutive or leading dots)");
        }
        if label.len() > 63 {
            return Err("Domain label exceeds 63 characters");
        }
        if label.starts_with('-') || label.ends_with('-') {
            return Err("Domain label cannot start or end with a hyphen");
        }
        if !label.bytes().all(|b| b.is_ascii_alphanumeric() || b == b'-' || b == b'_') {
            return Err("Domain label contains invalid characters \
(ASCII alphanumeric, hyphens, underscores only)");
        }
    }
    Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;
    use axum::body::Body;
    use axum::http::{Request, StatusCode};
    use http_body_util::BodyExt;
    use tower::ServiceExt;

    /// Bearer token installed by `make_test_app` and sent by authenticated requests.
    const TEST_KEY: &str = "test-api-key-for-unit-tests";

    /// Builds the API router over a minimal in-memory `AppState` keyed to `TEST_KEY`.
    fn make_test_app() -> Router {
        init_api_key(Some(TEST_KEY.to_string()));
        let _ = crate::runtime::BASE_DIR.set(std::path::PathBuf::from("/tmp/runbound-test"));

        let zones = Arc::new(ArcSwap::new(Arc::new(
            crate::dns::local::LocalZoneSet::default(),
        )));
        let cfg = Arc::new(crate::config::parser::UnboundConfig::default());
        let log_buffer = crate::logbuffer::new_shared(1000, true);
        let upstreams = crate::upstreams::init_upstreams(&cfg);

        router(AppState {
            zones,
            zones_mutex: Arc::new(tokio::sync::Mutex::new(())),
            tls_cfg: Arc::new(crate::config::parser::TlsConfig::default()),
            rate_limiter: ApiRateLimiter::new_public(),
            stats: crate::stats::Stats::new(),
            cfg,
            cfg_path: "/dev/null".to_string(),
            log_buffer,
            upstreams,
            sync_journal: None,
            slave_mode: false,
            base_dir: Arc::new(std::path::PathBuf::from("/tmp/runbound-test")),
            audit: crate::audit::init(false, None, None, std::path::PathBuf::from("/tmp")),
        })
    }

    /// Collects a response body and parses it as JSON, yielding `Null` on invalid JSON.
    async fn body_json(body: axum::body::Body) -> serde_json::Value {
        let raw = body.collect().await.unwrap().to_bytes();
        serde_json::from_slice(&raw).unwrap_or(serde_json::Value::Null)
    }

    /// The `Authorization` header (name, value) carrying `TEST_KEY`.
    fn auth_header() -> (&'static str, String) {
        ("Authorization", format!("Bearer {}", TEST_KEY))
    }

    /// Sends one GET to a freshly built test app, with or without the bearer token.
    async fn send_get(uri: &str, authenticated: bool) -> axum::response::Response {
        let mut builder = Request::builder().uri(uri);
        if authenticated {
            let (name, value) = auth_header();
            builder = builder.header(name, value);
        }
        make_test_app()
            .oneshot(builder.body(Body::empty()).unwrap())
            .await
            .unwrap()
    }

    /// Four dot-joined labels of 63, 63, 63 and `last` bytes, plus an optional root dot.
    fn long_name(last: usize, trailing_dot: bool) -> String {
        let mut name = ["a", "b", "c"]
            .iter()
            .map(|c| c.repeat(63))
            .chain(std::iter::once("d".repeat(last)))
            .collect::<Vec<_>>()
            .join(".");
        if trailing_dot {
            name.push('.');
        }
        name
    }

    #[tokio::test]
    async fn stats_requires_auth() {
        let resp = send_get("/stats", false).await;
        assert_eq!(resp.status(), StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn stats_schema() {
        let resp = send_get("/stats", true).await;
        assert_eq!(resp.status(), StatusCode::OK);
        let json = body_json(resp.into_body()).await;
        let expected = [
            "total", "blocked", "forwarded", "qps_1m", "qps_5m",
            "latency_p50_ms", "cache_hit_rate", "local_hits",
        ];
        for field in &expected {
            assert!(json.get(field).is_some(), "missing field: {field}");
        }
    }

    #[tokio::test]
    async fn stats_stream_requires_auth() {
        let resp = send_get("/stats/stream", false).await;
        assert_eq!(resp.status(), StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn stats_stream_content_type() {
        let resp = send_get("/stats/stream", true).await;
        assert_eq!(resp.status(), StatusCode::OK);
        let ct = resp
            .headers()
            .get("content-type")
            .and_then(|v| v.to_str().ok())
            .unwrap_or("");
        assert!(ct.contains("text/event-stream"), "unexpected Content-Type: {ct}");
    }

    #[tokio::test]
    async fn upstreams_requires_auth() {
        let resp = send_get("/upstreams", false).await;
        assert_eq!(resp.status(), StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn upstreams_schema() {
        let resp = send_get("/upstreams", true).await;
        assert_eq!(resp.status(), StatusCode::OK);
        let json = body_json(resp.into_body()).await;
        for field in ["upstreams", "total", "healthy"] {
            assert!(json.get(field).is_some());
        }
    }

    #[tokio::test]
    async fn logs_requires_auth() {
        let resp = send_get("/logs", false).await;
        assert_eq!(resp.status(), StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn logs_schema() {
        let resp = send_get("/logs", true).await;
        assert_eq!(resp.status(), StatusCode::OK);
        let json = body_json(resp.into_body()).await;
        for field in ["entries", "total"] {
            assert!(json.get(field).is_some());
        }
    }

    #[tokio::test]
    async fn logs_limit_too_large() {
        let resp = send_get("/logs?limit=2000", true).await;
        assert_eq!(resp.status(), StatusCode::UNPROCESSABLE_ENTITY);
    }

    #[tokio::test]
    async fn logs_invalid_action() {
        let resp = send_get("/logs?action=invalid", true).await;
        assert_eq!(resp.status(), StatusCode::BAD_REQUEST);
    }

    #[tokio::test]
    async fn logs_invalid_client_ip() {
        let resp = send_get("/logs?client=notanip", true).await;
        assert_eq!(resp.status(), StatusCode::BAD_REQUEST);
    }

    #[test]
    fn test_validate_dns_name_253_chars_accepted() {
        let name = long_name(61, false);
        assert_eq!(name.len(), 253);
        assert!(validate_dns_name(&name).is_ok());
    }

    #[test]
    fn test_validate_dns_name_254_chars_rejected() {
        let name = long_name(62, false);
        assert_eq!(name.len(), 254);
        assert!(validate_dns_name(&name).is_err());
    }

    #[test]
    fn test_validate_dns_name_253_with_trailing_dot_accepted() {
        let name = long_name(61, true);
        assert_eq!(name.trim_end_matches('.').len(), 253);
        assert!(validate_dns_name(&name).is_ok());
    }

    #[test]
    fn test_validate_dns_name_254_with_trailing_dot_rejected() {
        let name = long_name(62, true);
        assert_eq!(name.trim_end_matches('.').len(), 254);
        assert!(validate_dns_name(&name).is_err());
    }

    #[test]
    fn test_validate_dns_name_label_64_chars_rejected() {
        assert!(validate_dns_name(&"a".repeat(64)).is_err());
    }

    #[test]
    fn test_validate_dns_name_label_63_chars_accepted() {
        assert!(validate_dns_name(&"a".repeat(63)).is_ok());
    }
}