use askama::Template;
use askama_web::WebTemplate;
use axum::{
extract::State,
http::StatusCode,
response::{IntoResponse, Redirect, Response},
};
use serde::Deserialize;
use crate::{
resolver::upstream::UpstreamConfig,
storage::upstreams::{NewUpstream, Transport, Upstream, UpstreamRepository},
web::{
AppState, Chrome,
auth::CurrentUser,
render::{WebError, WebResult},
},
};
impl AppState {
    /// Rebuilds the live upstream pool from the enabled upstream rows in the database.
    ///
    /// The pool is built from every enabled row that converts into an
    /// [`UpstreamConfig`]. It uses the resolver's current selection strategy and
    /// parallel-fanout settings, and is then stored in `self.upstream_pool`.
    ///
    /// # Errors
    ///
    /// Returns an error if listing the enabled upstreams from the database fails.
    pub(crate) async fn rebuild_upstream_pool(&self) -> WebResult<()> {
        let rows = self.db.upstreams().list_enabled().await?;
        // NOTE(review): rows that fail conversion are dropped silently here; consider
        // logging them so a bad stored row doesn't vanish from the pool unnoticed.
        let configs: Vec<_> = rows
            .iter()
            .filter_map(|r| UpstreamConfig::try_from(r).ok())
            .collect();
        let settings = self.resolver.settings();
        let new_pool = crate::resolver::pipeline::engine::build_upstream_pool(
            &configs,
            &self.tracker,
            settings.upstream_selection_strategy,
            settings.upstream_parallel_fanout,
        )
        .await;
        // NOTE(review): nothing in this method serializes concurrent rebuilds. If two
        // run at once, whichever stores last wins, even if it read older rows. Confirm
        // whether callers need a lock around list+store.
        self.upstream_pool.store(new_pool);
        Ok(())
    }

    /// Loads every upstream returned by `list()` and builds the upstreams page.
    ///
    /// A missing TLS server name is rendered as an empty string. `error`, when
    /// present, is shown on the page (used after a rejected add).
    ///
    /// # Errors
    ///
    /// Returns an error if listing the upstreams from the database fails.
    async fn render_upstreams(
        &self,
        user: &CurrentUser,
        error: Option<String>,
    ) -> WebResult<UpstreamsPageTemplate> {
        let upstreams = self
            .db
            .upstreams()
            .list()
            .await?
            .into_iter()
            .map(|u| UpstreamView {
                id: u.id,
                address: u.address,
                transport: u.transport.as_str(),
                server_name: u.tls_server_name.unwrap_or_default(),
                enabled: u.enabled,
            })
            .collect();
        Ok(UpstreamsPageTemplate {
            chrome: self.chrome("upstreams", user).await,
            upstreams,
            error,
        })
    }

    /// Handler that renders the upstreams page with no error message.
    pub async fn upstreams_page(
        user: CurrentUser,
        State(state): State<AppState>,
    ) -> WebResult<Response> {
        Ok(state.render_upstreams(&user, None).await?.into_response())
    }

    /// Handler for the add-upstream form.
    ///
    /// On success it redirects back to `/upstreams`. When validation fails with
    /// [`WebError::BadRequest`], it re-renders the page with that message and a
    /// `400 Bad Request` status instead of redirecting. Any other error is
    /// propagated.
    pub async fn upstream_add(
        user: CurrentUser,
        State(state): State<AppState>,
        axum::Form(form): axum::Form<AddUpstreamForm>,
    ) -> WebResult<Response> {
        match state.add_upstream(form).await {
            Ok(()) => Ok(Redirect::to("/upstreams").into_response()),
            Err(WebError::BadRequest(msg)) => {
                let page = state.render_upstreams(&user, Some(msg)).await?;
                Ok((StatusCode::BAD_REQUEST, page).into_response())
            }
            Err(e) => Err(e),
        }
    }

    /// Validates `form`, persists it as an enabled upstream, and rebuilds the pool.
    ///
    /// All text fields are trimmed first. A blank server name is treated as absent.
    ///
    /// # Errors
    ///
    /// Returns [`WebError::BadRequest`] when:
    /// - the transport is unknown;
    /// - a DoT/DoH upstream has no TLS server name;
    /// - the address cannot be converted into an [`UpstreamConfig`].
    ///
    /// Database and pool-rebuild errors are propagated unchanged.
    async fn add_upstream(&self, form: AddUpstreamForm) -> WebResult<()> {
        // Trim the transport like the other fields, so stray whitespace from the
        // client doesn't turn a valid choice into a rejection.
        let transport: Transport = form
            .transport
            .trim()
            .parse()
            .map_err(|_| WebError::bad_request("Transport must be one of udp, tcp, dot, doh."))?;
        let address = form.address.trim().to_owned();
        let server_name = match form.server_name.trim() {
            "" => None,
            s => Some(s.to_owned()),
        };
        if matches!(transport, Transport::Dot | Transport::Doh) && server_name.is_none() {
            return Err(WebError::bad_request(
                "A TLS server name is required for DoT and DoH upstreams.",
            ));
        }
        // Validate with the same conversion `rebuild_upstream_pool` applies, so we
        // never persist a row the pool would silently drop. The validated candidate
        // is then moved into the insert, so what we check is exactly what we store.
        let candidate = Upstream {
            id: 0,
            address,
            transport,
            tls_server_name: server_name,
            enabled: true,
            sort_order: 0,
        };
        if UpstreamConfig::try_from(&candidate).is_err() {
            return Err(WebError::bad_request(
                "Address must be an IP address (optionally with :port); hostnames and DoH URLs are not supported in v0.1.",
            ));
        }
        self.db
            .upstreams()
            .insert(NewUpstream {
                address: candidate.address,
                transport: candidate.transport,
                tls_server_name: candidate.tls_server_name,
                enabled: candidate.enabled,
                sort_order: candidate.sort_order,
            })
            .await?;
        self.rebuild_upstream_pool().await
    }

    /// Handler that deletes the upstream with the submitted id, rebuilds the pool,
    /// and redirects to `/upstreams`.
    ///
    /// The `CurrentUser` extractor is required (presumably to enforce login) but its
    /// value is otherwise unused.
    pub async fn upstream_remove(
        _user: CurrentUser,
        State(state): State<AppState>,
        axum::Form(form): axum::Form<UpstreamIdForm>,
    ) -> WebResult<Response> {
        state.db.upstreams().delete(form.id).await?;
        state.rebuild_upstream_pool().await?;
        Ok(Redirect::to("/upstreams").into_response())
    }

    /// Handler that sets the enabled flag of the submitted upstream, rebuilds the
    /// pool, and redirects to `/upstreams`.
    ///
    /// The `CurrentUser` extractor is required (presumably to enforce login) but its
    /// value is otherwise unused.
    pub async fn upstream_toggle(
        _user: CurrentUser,
        State(state): State<AppState>,
        axum::Form(form): axum::Form<ToggleUpstreamForm>,
    ) -> WebResult<Response> {
        state
            .db
            .upstreams()
            .set_enabled(form.id, form.enabled)
            .await?;
        state.rebuild_upstream_pool().await?;
        Ok(Redirect::to("/upstreams").into_response())
    }
}
/// Form body for adding a new upstream, extracted via `axum::Form` in `upstream_add`.
///
/// Fields hold raw user input; `add_upstream` validates them before anything is stored.
#[derive(Debug, Deserialize)]
pub struct AddUpstreamForm {
    /// Upstream address. Per the validation message in `add_upstream`, this is
    /// expected to be an IP address, optionally with `:port`.
    address: String,
    /// Transport name, parsed into `Transport` (udp, tcp, dot or doh).
    transport: String,
    /// TLS server name; required for DoT and DoH. Empty when the field is
    /// absent from the submitted form.
    #[serde(default)]
    server_name: String,
}
/// Form body identifying a single upstream by its database id (used by `upstream_remove`).
#[derive(Debug, Deserialize)]
pub struct UpstreamIdForm {
    /// Database id of the upstream to act on.
    id: i64,
}
/// Form body for `upstream_toggle`: which upstream to change and its new enabled state.
#[derive(Debug, Deserialize)]
pub struct ToggleUpstreamForm {
    /// Database id of the upstream to enable or disable.
    id: i64,
    /// Desired enabled state.
    // NOTE(review): this must deserialize as a `bool` from the form encoding.
    // Presumably the template submits `true`/`false`; confirm, because a bare
    // HTML checkbox would send `on` (or nothing) instead.
    enabled: bool,
}
/// One upstream row as passed to the `upstreams.html` template.
struct UpstreamView {
    /// Database id, used by the page's remove/toggle forms.
    id: i64,
    /// Stored upstream address.
    address: String,
    /// Transport name as returned by `Transport::as_str`.
    transport: &'static str,
    /// TLS server name, or an empty string when none is configured.
    server_name: String,
    /// Whether the upstream is currently enabled.
    enabled: bool,
}
/// Askama template for the upstreams management page (`upstreams.html`).
#[derive(Template, WebTemplate)]
#[template(path = "upstreams.html")]
struct UpstreamsPageTemplate {
    /// Shared page chrome for the "upstreams" section.
    chrome: Chrome,
    /// Every stored upstream, in the order returned by the repository's `list()`.
    upstreams: Vec<UpstreamView>,
    /// Validation message to display, set when re-rendering after a rejected add.
    error: Option<String>,
}
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// Creates an `AppState` over a fresh temporary database.
    ///
    /// Callers must keep the returned `TempDir` alive for the whole test. It
    /// presumably holds the database files and is removed when dropped.
    async fn state() -> (TempDir, AppState) {
        let (dir, db) = crate::test_support::temp_db().await;
        (dir, AppState::for_test(db).await)
    }

    /// Builds an `AddUpstreamForm` from raw (untrimmed) field values.
    fn form(address: &str, transport: &str, sni: &str) -> AddUpstreamForm {
        AddUpstreamForm {
            address: address.to_owned(),
            transport: transport.to_owned(),
            server_name: sni.to_owned(),
        }
    }

    /// Number of stored upstream rows.
    async fn upstream_count(st: &AppState) -> usize {
        st.db.upstreams().list().await.unwrap().len()
    }

    /// Number of stored upstream rows whose address is exactly `address`.
    async fn count_with_address(st: &AppState, address: &str) -> usize {
        st.db
            .upstreams()
            .list()
            .await
            .unwrap()
            .iter()
            .filter(|u| u.address == address)
            .count()
    }

    /// Asserts that `input` is rejected as a bad request and that the rejection
    /// leaves the stored upstream list untouched.
    async fn assert_rejected(st: &AppState, input: AddUpstreamForm) {
        let before = upstream_count(st).await;
        let err = st.add_upstream(input).await.unwrap_err();
        assert!(matches!(err, WebError::BadRequest(_)));
        assert_eq!(
            upstream_count(st).await,
            before,
            "a rejected upstream must not be persisted"
        );
    }

    #[tokio::test]
    async fn add_ip_upstream_persists_and_rebuilds() {
        let (_d, st) = state().await;
        let before = upstream_count(&st).await;
        st.add_upstream(form("9.9.9.9", "udp", ""))
            .await
            .expect("add");
        let after = st.db.upstreams().list().await.unwrap();
        assert_eq!(after.len(), before + 1);
        assert!(after.iter().any(|u| u.address == "9.9.9.9"));
    }

    #[tokio::test]
    async fn add_trims_surrounding_whitespace_from_address() {
        let (_d, st) = state().await;
        let before = count_with_address(&st, "149.112.112.112").await;
        st.add_upstream(form("  149.112.112.112  ", "udp", ""))
            .await
            .expect("add");
        assert_eq!(count_with_address(&st, "149.112.112.112").await, before + 1);
    }

    #[tokio::test]
    async fn add_dot_without_sni_is_rejected() {
        let (_d, st) = state().await;
        assert_rejected(&st, form("1.1.1.1", "dot", "")).await;
    }

    #[tokio::test]
    async fn add_dot_with_blank_sni_is_rejected() {
        let (_d, st) = state().await;
        assert_rejected(&st, form("1.1.1.1", "dot", "   ")).await;
    }

    #[tokio::test]
    async fn add_hostname_upstream_is_rejected() {
        let (_d, st) = state().await;
        assert_rejected(&st, form("dns.quad9.net", "udp", "")).await;
    }

    #[tokio::test]
    async fn add_blank_address_is_rejected() {
        let (_d, st) = state().await;
        assert_rejected(&st, form("   ", "udp", "")).await;
    }

    #[tokio::test]
    async fn bad_transport_is_rejected() {
        let (_d, st) = state().await;
        assert_rejected(&st, form("9.9.9.9", "grpc", "")).await;
    }
}