use crate::access_log::AccessLogger;
use crate::cert::acme::ChallengeMap;
use crate::auth::Authenticator;
use crate::error::ErrorPages;
use crate::geoip;
use crate::metrics::Metrics;
use crate::router::Router;
use std::sync::Arc;
use std::time::Duration;
use tokio::task::JoinSet;
// Private submodules of this module, each followed by the items re-exported from it.
// The submodule bodies are not visible here; descriptions below are based on names only.
//
// `http`: re-exports `run_plain` and `run_tls` (presumably the plaintext and TLS HTTP
// listeners — confirm against `http.rs`).
mod http;
pub use http::{run_plain, run_tls};
// `quic`: re-exports `run_quic`.
mod quic;
pub use quic::run_quic;
// `stream`: re-exports `run_stream_proxy`.
mod stream;
pub use stream::run_stream_proxy;
// `datagram`: re-exports `run_dgram_proxy`.
mod datagram;
pub use datagram::run_dgram_proxy;
// `service`: `FirstRequest` and `HypershuntService` are only visible to the parent module.
mod service;
pub(super) use service::{FirstRequest, HypershuntService};
// `socket`: socket binding helpers and address types, re-exported publicly.
mod socket;
pub use socket::{BoundSocket, LocalAddr, LocalUnixPath, bind_socket};
// NOTE(review): the reason for this `allow` is not visible here — presumably the re-export
// is unused in some build configurations; confirm before removing.
#[allow(unused_imports)]
pub use socket::bind_tcp_socket;
// NOTE(review): same as above — crate-internal re-exports that may be unused in some
// configurations; confirm.
#[allow(unused_imports)]
pub(crate) use socket::{IncomingStream, PeerAddr, apply_proxy_proto};
/// Upper bound (30 s) on how long `drain_connections` waits for in-flight connection
/// tasks to finish before giving up on the remainder.
pub(super) const DRAIN_TIMEOUT: Duration = Duration::from_secs(30);
/// Time limit (10 s) for a TLS handshake. The code applying it is outside this view
/// (presumably the TLS listener — confirm).
pub(super) const TLS_HANDSHAKE_TIMEOUT: Duration = Duration::from_secs(10);
/// Default header timeout, expressed in seconds.
/// NOTE(review): presumably the fallback used when no header timeout is configured —
/// confirm against the code that reads it.
pub(super) const DEFAULT_HEADER_TIMEOUT_SECS: u64 = 30;
/// Pause (100 ms) taken by `backoff_after_accept_error` after an accept failure caused by
/// resource exhaustion.
const ACCEPT_ERROR_BACKOFF: Duration = Duration::from_millis(100);
/// Reports a failed `accept` call and, on Unix, pauses when the failure looks like
/// resource exhaustion.
///
/// The error is always logged at `error` level together with the `bind` label of the
/// listener. On Unix targets, if the raw OS error code is `EMFILE`, `ENFILE`, `ENOBUFS`
/// or `ENOMEM`, the task additionally sleeps for `ACCEPT_ERROR_BACKOFF` before returning.
/// Any other error (and every error on non-Unix targets) returns immediately after logging.
async fn backoff_after_accept_error(bind: &str, e: &std::io::Error) {
    tracing::error!(bind = %bind, "accept error: {e}");

    // The `libc` error constants only exist on Unix, so the whole check is compiled out
    // elsewhere.
    #[cfg(unix)]
    match e.raw_os_error() {
        Some(libc::EMFILE | libc::ENFILE | libc::ENOBUFS | libc::ENOMEM) => {
            tokio::time::sleep(ACCEPT_ERROR_BACKOFF).await;
        }
        _ => {}
    }
}
/// Shared, reference-counted handle to an `ArcSwap` holding the current [`AppState`].
///
/// NOTE(review): the `ArcSwap` wrapper presumably exists so the whole state can be replaced
/// atomically (e.g. on configuration reload) while readers keep their snapshot — confirm
/// against the code that stores into it.
pub type SharedAppState = Arc<arc_swap::ArcSwap<AppState>>;
/// Bundle of the shared components that make up the application's state.
///
/// Most fields are `Arc` handles, so the struct holds references to the components
/// rather than owning them exclusively. Optional components are `None` when absent
/// (presumably when the corresponding feature is not configured — confirm).
pub struct AppState {
/// Shared handle to the `Router`.
pub router: Arc<Router>,
/// ACME challenge map (`crate::cert::acme::ChallengeMap`).
pub acme_challenges: ChallengeMap,
/// Shared `Authenticator` implementation, used through dynamic dispatch.
pub authenticator: Arc<dyn Authenticator>,
/// Shared metrics; `drain_connections` updates its shutdown counters.
pub metrics: Arc<Metrics>,
/// Optional GeoIP country reader.
pub geoip: Option<Arc<geoip::CountryReader>>,
/// Shared `HealthState` from `crate::handler::health`.
pub health: Arc<crate::handler::health::HealthState>,
/// Shared `ErrorPages`.
pub error_pages: Arc<ErrorPages>,
/// Optional `JwtManager`.
pub jwt_manager: Option<Arc<crate::jwt::JwtManager>>,
/// Optional `OidcProvider`.
pub oidc: Option<Arc<crate::oidc::OidcProvider>>,
/// Shared `AccessLogger`.
pub access_log: Arc<AccessLogger>,
/// Optional `CacheStore`.
pub cache: Option<Arc<crate::cache::CacheStore>>,
}
/// Waits for the tasks in `connections` to finish, for at most `DRAIN_TIMEOUT`, and records
/// the outcome in the shutdown counters of `metrics`.
///
/// * `name` — label of the listener, attached to log events as the `bind` field.
/// * `connections` — the set of in-flight connection tasks; consumed by this function.
/// * `metrics` — receives the counts: `shutdown_drained_total` is increased by the number
///   of tasks that completed, `shutdown_abandoned_total` by the number still pending when
///   the timeout fired.
///
/// Tasks still pending at the timeout stay in `connections`, which is dropped when this
/// function returns (tokio documents that dropping a `JoinSet` aborts its remaining tasks).
pub(super) async fn drain_connections(
    name: &str,
    mut connections: JoinSet<()>,
    metrics: &Metrics,
) {
    use std::sync::atomic::Ordering;

    let total = connections.len();
    if total != 0 {
        tracing::info!(bind = %name, connections = total, "draining");
    }

    // The timeout future, and with it the mutable borrow of `connections`, is dropped at
    // the end of this statement, so the set can be inspected again below.
    let drained_in_time = tokio::time::timeout(DRAIN_TIMEOUT, async {
        loop {
            if connections.join_next().await.is_none() {
                break;
            }
        }
    })
    .await
    .is_ok();

    if drained_in_time {
        metrics
            .shutdown_drained_total
            .fetch_add(total as u64, Ordering::Relaxed);
        return;
    }

    // Timed out: whatever is still in the set did not finish in time.
    let abandoned = connections.len();
    let completed = total - abandoned;
    metrics
        .shutdown_abandoned_total
        .fetch_add(abandoned as u64, Ordering::Relaxed);
    metrics
        .shutdown_drained_total
        .fetch_add(completed as u64, Ordering::Relaxed);
    tracing::warn!(
        bind = %name,
        "drain timeout after {}s; {} connection(s) abandoned",
        DRAIN_TIMEOUT.as_secs(),
        abandoned,
    );
}
// Unit tests for this module live in the separate `tests` submodule, which is compiled
// only for test builds.
#[cfg(test)]
mod tests;