use std::net::SocketAddr;
use std::pin::Pin;
use std::sync::Arc;
use derivative::Derivative;
use futures::channel::oneshot;
use futures::prelude::*;
use itertools::Itertools;
use multimap::MultiMap;
use tokio::sync::mpsc;
use super::router::ApolloRouterError;
use crate::configuration::Configuration;
use crate::configuration::ListenAddr;
use crate::router_factory::Endpoint;
use crate::router_factory::RouterFactory;
use crate::uplink::license_enforcement::LicenseState;
/// Factory for the HTTP servers that serve a router, plus toggles for their live/ready state.
pub(crate) trait HttpServerFactory {
    /// Future returned by [`HttpServerFactory::create`]; it resolves to a handle on the started
    /// servers, or to the error that prevented them from starting.
    type Future: Future<Output = Result<HttpServerHandle, ApolloRouterError>> + Send;

    /// Starts the servers for `service_factory` using `configuration`.
    ///
    /// `main_listener` and `previous_listeners` carry listeners taken from servers that were
    /// just stopped (`HttpServerHandle::restart` passes them through); `None` / an empty list
    /// presumably means fresh listeners must be bound — confirm in implementations.
    /// `extra_endpoints` groups additional endpoints by the address they should be served on.
    /// `all_connections_stopped_sender` is handed over unchanged from one server generation to
    /// the next on restart.
    // Every input is required to build the servers; bundling them into a one-off struct would
    // not make call sites any clearer, hence the lint exemption.
    #[allow(clippy::too_many_arguments)]
    fn create<RF: RouterFactory>(
        &self,
        service_factory: RF,
        configuration: Arc<Configuration>,
        main_listener: Option<Listener>,
        previous_listeners: ExtraListeners,
        extra_endpoints: MultiMap<ListenAddr, Endpoint>,
        license: LicenseState,
        all_connections_stopped_sender: mpsc::Sender<()>,
    ) -> Self::Future;

    /// Updates the "live" flag; how it is exposed is up to the implementation (TODO confirm).
    fn live(&self, live: bool);

    /// Updates the "ready" flag; how it is exposed is up to the implementation (TODO confirm).
    fn ready(&self, ready: bool);
}
/// Listeners of the extra (non-main) servers, each paired with the address it serves.
type ExtraListeners = Vec<(ListenAddr, Listener)>;

/// Handle on a set of running HTTP servers: one main server plus zero or more extra servers.
///
/// Dropping the handle does not stop anything by itself; use `shutdown` or `restart`.
#[derive(Derivative)]
#[derivative(Debug)]
pub(crate) struct HttpServerHandle {
    /// Notifies the main server that it should stop.
    main_shutdown_sender: oneshot::Sender<()>,
    /// Notifies the extra servers that they should stop.
    extra_shutdown_sender: oneshot::Sender<()>,
    /// Completes with the main listener once the main server has stopped.
    #[derivative(Debug = "ignore")]
    main_future:
        Pin<Box<dyn Future<Output = Result<Listener, ApolloRouterError>> + Send + 'static>>,
    /// Completes with the extra listeners once the extra servers have stopped.
    #[derivative(Debug = "ignore")]
    extra_futures:
        Pin<Box<dyn Future<Output = Result<ExtraListeners, ApolloRouterError>> + Send + 'static>>,
    /// Every address the servers listen on.
    listen_addresses: Vec<ListenAddr>,
    /// Address of the GraphQL endpoint, when known.
    graphql_listen_address: Option<ListenAddr>,
    /// Passed on unchanged to the next server generation on restart.
    all_connections_stopped_sender: mpsc::Sender<()>,
}
impl HttpServerHandle {
pub(crate) fn new(
main_shutdown_sender: oneshot::Sender<()>,
extra_shutdown_sender: oneshot::Sender<()>,
main_future: Pin<
Box<dyn Future<Output = Result<Listener, ApolloRouterError>> + Send + 'static>,
>,
extra_futures: Pin<
Box<dyn Future<Output = Result<ExtraListeners, ApolloRouterError>> + Send + 'static>,
>,
graphql_listen_address: Option<ListenAddr>,
listen_addresses: Vec<ListenAddr>,
all_connections_stopped_sender: mpsc::Sender<()>,
) -> Self {
Self {
main_shutdown_sender,
extra_shutdown_sender,
main_future,
extra_futures,
graphql_listen_address,
listen_addresses,
all_connections_stopped_sender,
}
}
pub(crate) async fn shutdown(mut self) -> Result<(), ApolloRouterError> {
#[cfg(unix)]
let listen_addresses = std::mem::take(&mut self.listen_addresses);
let (_main_listener, _extra_listener) = self.wait_for_servers().await?;
#[cfg(unix)]
for listen_address in listen_addresses {
if let ListenAddr::UnixSocket(path) = listen_address {
let _ = tokio::fs::remove_file(path).await;
}
}
Ok(())
}
pub(crate) async fn restart<RF, SF>(
self,
factory: &SF,
router: RF,
configuration: Arc<Configuration>,
web_endpoints: MultiMap<ListenAddr, Endpoint>,
license: LicenseState,
) -> Result<Self, ApolloRouterError>
where
SF: HttpServerFactory,
RF: RouterFactory,
{
let all_connections_stopped_sender = self.all_connections_stopped_sender.clone();
let (main_listener, extra_listeners) = self.wait_for_servers().await?;
tracing::debug!("previous server stopped");
let handle = factory
.create(
router,
configuration,
Some(main_listener),
extra_listeners,
web_endpoints,
license,
all_connections_stopped_sender,
)
.await?;
tracing::debug!(
"restarted on {}",
handle
.listen_addresses()
.iter()
.map(std::string::ToString::to_string)
.join(" - ")
);
Ok(handle)
}
pub(crate) fn listen_addresses(&self) -> &[ListenAddr] {
self.listen_addresses.as_slice()
}
pub(crate) fn graphql_listen_address(&self) -> &Option<ListenAddr> {
&self.graphql_listen_address
}
async fn wait_for_servers(self) -> Result<(Listener, ExtraListeners), ApolloRouterError> {
if let Err(_err) = self.main_shutdown_sender.send(()) {
tracing::error!("Failed to notify http thread of shutdown")
};
let main_listener = self.main_future.await?;
if let Err(_err) = self.extra_shutdown_sender.send(()) {
tracing::error!("Failed to notify http thread of shutdown")
};
let extra_listeners = self.extra_futures.await?;
Ok((main_listener, extra_listeners))
}
}
/// A bound server socket that the HTTP server accepts connections from.
pub(crate) enum Listener {
    /// Plain TCP.
    Tcp(tokio::net::TcpListener),
    /// Unix domain socket (unix targets only).
    #[cfg(unix)]
    Unix(tokio::net::UnixListener),
    /// TCP whose accepted connections go through `acceptor` before being handed out.
    Tls {
        listener: tokio::net::TcpListener,
        acceptor: tokio_rustls::TlsAcceptor,
    },
}

/// An accepted connection; each variant corresponds to the [`Listener`] variant it came from.
pub(crate) enum NetworkStream {
    /// Connection accepted by [`Listener::Tcp`].
    Tcp(tokio::net::TcpStream),
    /// Connection accepted by `Listener::Unix` (unix targets only).
    #[cfg(unix)]
    Unix(tokio::net::UnixStream),
    /// Connection accepted by [`Listener::Tls`], after the TLS acceptor has run on it.
    Tls(tokio_rustls::server::TlsStream<tokio::net::TcpStream>),
}
impl Listener {
    /// Binds a TCP listener on `address`, serving TLS when `tls_acceptor` is provided.
    ///
    /// # Errors
    ///
    /// Returns `ApolloRouterError::ServerCreationError` when `address` cannot be bound.
    pub(crate) async fn new_from_socket_addr(
        address: SocketAddr,
        tls_acceptor: Option<tokio_rustls::TlsAcceptor>,
    ) -> Result<Self, ApolloRouterError> {
        let listener = tokio::net::TcpListener::bind(address)
            .await
            .map_err(ApolloRouterError::ServerCreationError)?;
        // Single place that decides between plain TCP and TLS.
        Ok(Self::new_from_listener(listener, tls_acceptor))
    }

    /// Wraps an already-bound TCP listener: [`Listener::Tcp`] without an acceptor,
    /// [`Listener::Tls`] with one.
    pub(crate) fn new_from_listener(
        listener: tokio::net::TcpListener,
        tls_acceptor: Option<tokio_rustls::TlsAcceptor>,
    ) -> Self {
        match tls_acceptor {
            None => Listener::Tcp(listener),
            Some(acceptor) => Listener::Tls { listener, acceptor },
        }
    }

    /// Returns the address this listener is bound to.
    ///
    /// For an unnamed unix socket (no filesystem path) this yields a `ListenAddr::UnixSocket`
    /// with an empty path.
    pub(crate) fn local_addr(&self) -> std::io::Result<ListenAddr> {
        match self {
            Listener::Tcp(listener) => listener.local_addr().map(Into::into),
            #[cfg(unix)]
            Listener::Unix(listener) => listener.local_addr().map(|addr| {
                ListenAddr::UnixSocket(
                    addr.as_pathname()
                        .map(ToOwned::to_owned)
                        .unwrap_or_default(),
                )
            }),
            Listener::Tls { listener, .. } => listener.local_addr().map(Into::into),
        }
    }

    /// Waits for the next incoming connection.
    ///
    /// For [`Listener::Tls`] the TLS acceptor is run on the connection before returning, and
    /// its failure is reported as this call's error.
    ///
    /// NOTE(review): the TLS handshake is awaited inside `accept`, so while one client's
    /// handshake is in progress no other connection is accepted from this listener — confirm
    /// that callers bound this (e.g. with a timeout or by handshaking in a spawned task).
    pub(crate) async fn accept(&mut self) -> std::io::Result<NetworkStream> {
        match self {
            Listener::Tcp(listener) => listener
                .accept()
                .await
                .map(|(stream, _)| NetworkStream::Tcp(stream)),
            #[cfg(unix)]
            Listener::Unix(listener) => listener
                .accept()
                .await
                .map(|(stream, _)| NetworkStream::Unix(stream)),
            Listener::Tls { listener, acceptor } => {
                let (stream, _) = listener.accept().await?;
                Ok(NetworkStream::Tls(acceptor.accept(stream).await?))
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use std::net::SocketAddr;
    use std::str::FromStr;

    use futures::channel::oneshot;
    use test_log::test;

    use super::*;

    /// Shutting down a TCP-backed handle signals both the main and the extra servers.
    #[test(tokio::test)]
    async fn sanity() {
        let (shutdown_sender, shutdown_receiver) = oneshot::channel();
        let (extra_shutdown_sender, extra_shutdown_receiver) = oneshot::channel();
        let listener = Listener::Tcp(tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap());
        let (all_connections_stopped_sender, _) = mpsc::channel::<()>(1);
        HttpServerHandle::new(
            shutdown_sender,
            extra_shutdown_sender,
            futures::future::ready(Ok(listener)).boxed(),
            futures::future::ready(Ok(vec![])).boxed(),
            Some(SocketAddr::from_str("127.0.0.1:0").unwrap().into()),
            Default::default(),
            all_connections_stopped_sender,
        )
        .shutdown()
        .await
        .expect("Should have waited for shutdown");
        shutdown_receiver
            .await
            .expect("Should have sent notification to shutdown");
        extra_shutdown_receiver
            .await
            .expect("Should have sent notification to shutdown");
    }

    /// Shutting down a unix-socket-backed handle signals both servers and removes the socket
    /// file of every unix socket listed in the listen addresses.
    #[test(tokio::test)]
    #[cfg(unix)]
    async fn sanity_unix() {
        let temp_dir = tempfile::tempdir().unwrap();
        let sock = temp_dir.as_ref().join("sock");
        let (shutdown_sender, shutdown_receiver) = oneshot::channel();
        let (extra_shutdown_sender, extra_shutdown_receiver) = oneshot::channel();
        let listener = Listener::Unix(tokio::net::UnixListener::bind(&sock).unwrap());
        assert!(sock.exists(), "binding should have created the socket file");
        let (all_connections_stopped_sender, _) = mpsc::channel::<()>(1);
        HttpServerHandle::new(
            shutdown_sender,
            extra_shutdown_sender,
            futures::future::ready(Ok(listener)).boxed(),
            futures::future::ready(Ok(vec![])).boxed(),
            Some(ListenAddr::UnixSocket(sock.clone())),
            vec![ListenAddr::UnixSocket(sock.clone())],
            all_connections_stopped_sender,
        )
        .shutdown()
        .await
        .expect("Should have waited for shutdown");
        assert!(
            !sock.exists(),
            "shutdown should have removed the unix socket file"
        );
        shutdown_receiver
            .await
            .expect("Should have sent notification to shutdown");
        extra_shutdown_receiver
            .await
            .expect("Should have sent notification to shutdown");
    }
}