use std::fmt;
pub mod client;
pub mod forward;
pub mod health;
pub mod pool;
pub use client::{UpstreamBackground, UpstreamClient};
pub use forward::{DEFAULT_QUERY_TIMEOUT, ForwardResult};
pub use health::{UpstreamHealth, UpstreamHealthRow};
pub use pool::{
DEFAULT_FAILOVER_BUDGET, LatencyWeightedSelector, RandomSelector, SharedUpstreamPool,
UpstreamObservation, UpstreamPool, UpstreamSelector,
};
/// Wire transport used to reach an upstream resolver.
///
/// The [`Display`](fmt::Display) form is a short label: `UDP`, `TCP`, `DoT` or `DoH`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum UpstreamTransport {
    /// Plain UDP.
    Udp,
    /// Plain TCP.
    Tcp,
    /// DNS over TLS ("DoT").
    Dot,
    /// DNS over HTTPS ("DoH").
    Doh,
}

impl fmt::Display for UpstreamTransport {
    /// Writes the transport's short label.
    ///
    /// `Formatter::pad` is used instead of `write_str` so that width, alignment
    /// and precision flags (e.g. `{:<4}` when printing aligned tables) are
    /// honored; with a plain `{}` the output is exactly the label.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let label = match self {
            Self::Udp => "UDP",
            Self::Tcp => "TCP",
            Self::Dot => "DoT",
            Self::Doh => "DoH",
        };
        f.pad(label)
    }
}
/// Settings describing how to reach one upstream resolver.
#[derive(Debug, Clone)]
pub struct UpstreamConfig {
    /// Socket address of the upstream server.
    pub addr: std::net::SocketAddr,
    /// Transport spoken to `addr`.
    pub transport: UpstreamTransport,
    /// TLS server name for the encrypted transports.
    ///
    /// Connecting with [`UpstreamTransport::Dot`] or [`UpstreamTransport::Doh`]
    /// while this is `None` fails with [`Error::InvalidServerName`].
    pub tls_server_name: Option<String>,
    /// Optional HTTP endpoint.
    ///
    /// NOTE(review): presumably only consulted for DoH — confirm against `client`.
    pub http_endpoint: Option<String>,
}
/// Errors produced while connecting to or querying upstream resolvers.
///
/// Fields named `source` are picked up by `thiserror` as the underlying cause
/// returned from [`std::error::Error::source`].
///
/// NOTE(review): `Connect` and `Exchange` also interpolate `{source}` into their
/// `Display` text, so reporters that walk the cause chain will print the inner
/// error twice.
#[derive(Debug, thiserror::Error)]
pub enum Error {
    /// Establishing a connection over `transport` failed.
    #[error("failed to connect upstream {transport} transport: {source}")]
    Connect {
        /// Transport that was being connected.
        transport: UpstreamTransport,
        /// Underlying network error (also returned as the error source).
        source: hickory_net::NetError,
    },

    /// The TLS server name was missing or invalid; the string is interpolated
    /// into the message.
    #[error("invalid or missing TLS server name: {0}")]
    InvalidServerName(String),

    /// A transport-level failure described only by free-form text.
    #[error("upstream transport error: {0}")]
    Transport(String),

    /// A query over `transport` timed out.
    #[error("upstream {transport} query timed out")]
    Timeout {
        /// Transport the timed-out query used.
        transport: UpstreamTransport,
    },

    /// A query exchange over `transport` failed.
    #[error("upstream {transport} exchange failed: {source}")]
    Exchange {
        /// Transport the exchange used.
        transport: UpstreamTransport,
        /// Underlying network error (also returned as the error source).
        source: hickory_net::NetError,
    },

    /// Every upstream tried failed; `attempts` is how many attempts were made.
    #[error("all upstreams failed after {attempts} attempt(s)")]
    AllUpstreamsFailed {
        /// Number of attempts made before giving up.
        attempts: usize,
    },
}
pub type Result<T> = std::result::Result<T, Error>;
#[cfg(test)]
mod tests {
    use std::net::{IpAddr, Ipv4Addr, SocketAddr};

    use super::*;

    /// Resolver address shared by every fixture in this module.
    const CLOUDFLARE: IpAddr = IpAddr::V4(Ipv4Addr::new(1, 1, 1, 1));

    /// Builds a config for `CLOUDFLARE:port`; `http_endpoint` is always unset.
    fn config(
        transport: UpstreamTransport,
        port: u16,
        tls_server_name: Option<&str>,
    ) -> UpstreamConfig {
        UpstreamConfig {
            addr: SocketAddr::new(CLOUDFLARE, port),
            transport,
            tls_server_name: tls_server_name.map(str::to_owned),
            http_endpoint: None,
        }
    }

    /// Connects without a TLS server name and asserts the client refuses.
    async fn assert_rejects_missing_server_name(transport: UpstreamTransport, port: u16) {
        let result = UpstreamClient::connect(&config(transport, port, None)).await;
        assert!(
            matches!(result, Err(Error::InvalidServerName(_))),
            "expected InvalidServerName error"
        );
    }

    #[test]
    fn transport_display() {
        let expected = [
            (UpstreamTransport::Udp, "UDP"),
            (UpstreamTransport::Tcp, "TCP"),
            (UpstreamTransport::Dot, "DoT"),
            (UpstreamTransport::Doh, "DoH"),
        ];
        for (transport, label) in expected {
            assert_eq!(transport.to_string(), label);
        }
    }

    #[test]
    fn config_construction() {
        let cfg = config(UpstreamTransport::Dot, 853, Some("one.one.one.one"));
        assert_eq!(cfg.addr, SocketAddr::new(CLOUDFLARE, 853));
        assert_eq!(cfg.transport, UpstreamTransport::Dot);
        assert_eq!(cfg.tls_server_name.as_deref(), Some("one.one.one.one"));
        assert!(cfg.http_endpoint.is_none());
    }

    #[tokio::test]
    async fn dot_missing_server_name_returns_error() {
        assert_rejects_missing_server_name(UpstreamTransport::Dot, 853).await;
    }

    #[tokio::test]
    async fn doh_missing_server_name_returns_error() {
        assert_rejects_missing_server_name(UpstreamTransport::Doh, 443).await;
    }
}