use super::signal_gate::SignalGate;
use super::task::{Executor, Spawner};
use super::{Server, handle_socket};
use crate::tracing::{self, Instrument};
use async_net::{AsyncToSocketAddrs, TcpListener};
#[cfg(feature = "tls")]
use rustls::ServerConfig;
use std::io::ErrorKind;
/// Serves plain-TCP HTTP on `addr` using the [`Server`] produced by `factory`.
///
/// `factory` is invoked once on the calling thread, before the listener is bound. The calling
/// thread then blocks in `async_io::block_on` while `executor.run` drives the accept loop.
///
/// # Panics
///
/// Panics if no listener can be bound to `addr`.
pub fn serve<F, T>(addr: impl AsyncToSocketAddrs, factory: F)
where
    F: FnOnce() -> T,
    T: Server,
{
    let server = factory();
    let executor = Executor::new();
    let spawner = executor.spawner();
    let main_task = async {
        let bound = TcpListener::bind(addr).await;
        let listener = bound.expect("Failed to bind to socket");
        accept_loop(listener, &server, spawner).await;
    };
    async_io::block_on(executor.run(main_task));
}
async fn accept_loop<'server: 'exec, 'exec, S: Server>(
listener: TcpListener,
server: &'server S,
spawner: Spawner<'exec>,
) {
let mut gate = SignalGate::new();
loop {
match gate.or_signal(listener.accept()).await {
Ok((socket, _addr)) => {
let _ = socket.set_nodelay(true);
spawner.spawn(
handle_socket(socket, server, spawner)
.instrument(tracing::info_span!("conn", remote = %_addr)),
);
}
Err(error) => {
if error.kind() == ErrorKind::Interrupted {
tracing::debug!(?error, "Got signal to shut down");
} else {
tracing::error!(?error, "Failed to accept connection, shutting down");
}
break;
}
}
}
gate.wait_for_shutdown(spawner).await;
}
/// Serves HTTPS on `addr`: TLS is terminated with `tls_config` on each accepted connection.
/// The decrypted stream is then handled by the [`Server`] produced by `factory`.
///
/// `factory` is invoked once on the calling thread, before the listener is bound. The calling
/// thread then blocks in `async_io::block_on` while `executor.run` drives the accept loop.
///
/// # Panics
///
/// Panics if no listener can be bound to `addr`.
#[cfg(feature = "tls")]
#[cfg_attr(docsrs, doc(cfg(feature = "tls")))]
pub fn serve_tls<F, T>(addr: impl AsyncToSocketAddrs, tls_config: ServerConfig, factory: F)
where
    F: FnOnce() -> T,
    T: Server,
{
    let server = factory();
    let executor = Executor::new();
    let spawner = executor.spawner();
    let main_task = async {
        let bound = TcpListener::bind(addr).await;
        let listener = bound.expect("Failed to bind to socket");
        accept_loop_tls(listener, tls_config, &server, spawner).await;
    };
    async_io::block_on(executor.run(main_task));
}
/// TLS counterpart of `accept_loop`.
///
/// Accepts TCP connections on `listener`. For each one it spawns a task that performs the TLS
/// handshake with `tls_config` and then hands the encrypted stream to `handle_socket`. The whole
/// task, handshake included, runs inside a `conn` span carrying the peer address.
///
/// The loop stops on a shutdown signal, or on an accept error that is not specific to one
/// connection. Afterwards `spawner` is handed to `SignalGate::wait_for_shutdown`.
///
/// Errors that only concern a single peer (`ConnectionAborted`, `ConnectionReset`) are logged
/// and skipped instead of taking the whole server down.
#[cfg(feature = "tls")]
async fn accept_loop_tls<'server: 'exec, 'exec, S: Server>(
    listener: TcpListener,
    tls_config: ServerConfig,
    server: &'server S,
    spawner: Spawner<'exec>,
) {
    use futures_rustls::TlsAcceptor;
    use std::sync::Arc;

    let mut gate = SignalGate::new();
    // Built once from an `Arc`'d config and cloned into each connection task.
    let acceptor = TlsAcceptor::from(Arc::new(tls_config));
    loop {
        match gate.or_signal(listener.accept()).await {
            Ok((socket, _addr)) => {
                // Best-effort: a failure to disable Nagle is no reason to drop the peer.
                let _ = socket.set_nodelay(true);
                let acceptor = acceptor.clone();
                spawner.spawn(
                    async move {
                        // The handshake runs inside the spawned task, so a slow client cannot
                        // stall the accept loop.
                        match acceptor.accept(socket).await {
                            Ok(tls_stream) => {
                                handle_socket(tls_stream, server, spawner).await;
                            }
                            Err(_err) => {
                                tracing::debug!(error = ?_err, "TLS handshake failed");
                            }
                        }
                    }
                    // The span covers the handshake too, so handshake failures are logged
                    // together with the peer address.
                    .instrument(tracing::info_span!("conn", remote = %_addr)),
                );
            }
            // NOTE(review): `SignalGate::or_signal` appears to surface a shutdown signal as
            // `ErrorKind::Interrupted`. Confirm in `signal_gate`.
            Err(error) if error.kind() == ErrorKind::Interrupted => {
                tracing::debug!(?error, "Got signal to shut down");
                break;
            }
            // The pending connection died before we accepted it. Only that peer is affected,
            // so keep serving everyone else.
            Err(error)
                if matches!(
                    error.kind(),
                    ErrorKind::ConnectionAborted | ErrorKind::ConnectionReset
                ) =>
            {
                tracing::debug!(?error, "Peer went away before accept, continuing");
            }
            Err(error) => {
                tracing::error!(?error, "Failed to accept connection, shutting down");
                break;
            }
        }
    }
    gate.wait_for_shutdown(spawner).await;
}
#[cfg(test)]
mod test {
    use super::*;
    use crate::http::{Body, Request, Response};
    use futures_lite::{AsyncReadExt, AsyncWriteExt};
    use std::net::SocketAddr;
    use std::time::{Duration, Instant};

    /// How long a test client keeps retrying while the server thread starts up.
    const STARTUP_TIMEOUT: Duration = Duration::from_secs(5);

    /// Minimal [`Server`] for these tests: answers every request with `Hello <path>`
    /// and panics on `/error`.
    struct MockServer;

    impl Server for MockServer {
        async fn route<'s: 'e, 'e>(
            &'s self,
            req: Request<'_, '_>,
            _spawner: Spawner<'e>,
        ) -> Response {
            if req.path == "/error" {
                panic!("forced panic");
            }
            Response::ok().with_body(format!("Hello {}", req.path), Body::DEFAULT_CONTENT_TYPE)
        }
    }

    /// Returns a loopback address whose port was free a moment ago.
    ///
    /// The probe listener is dropped before the server binds the port. Another process could
    /// in principle take the port in between; that race is accepted for tests.
    fn free_local_addr() -> SocketAddr {
        futures_lite::future::block_on(async {
            let listener = async_net::TcpListener::bind("127.0.0.1:0").await.unwrap();
            listener.local_addr().unwrap()
        })
    }

    /// Connects to `addr`, retrying until the server thread has bound its listener or
    /// [`STARTUP_TIMEOUT`] elapses.
    ///
    /// Polling replaces a fixed sleep. The sleep was slower than needed on fast machines and
    /// could still lose the race against server startup on slow or loaded CI hosts.
    async fn connect_with_retry(addr: SocketAddr) -> async_net::TcpStream {
        let deadline = Instant::now() + STARTUP_TIMEOUT;
        loop {
            match async_net::TcpStream::connect(addr).await {
                Ok(stream) => return stream,
                Err(_) if Instant::now() < deadline => {
                    async_io::Timer::after(Duration::from_millis(10)).await;
                }
                Err(error) => panic!("server at {addr} did not become reachable: {error}"),
            }
        }
    }

    #[test]
    fn test_serve() {
        let addr = free_local_addr();
        std::thread::spawn(move || {
            serve(addr, || MockServer);
        });
        futures_lite::future::block_on(async {
            let mut stream = connect_with_retry(addr).await;
            stream
                .write_all(b"GET /serve HTTP/1.1\r\n\r\n")
                .await
                .unwrap();
            // Assumes the whole (tiny) response arrives in one read on loopback.
            let mut buf = vec![0u8; 1024];
            let n = stream.read(&mut buf).await.unwrap();
            let response = std::str::from_utf8(&buf[..n]).unwrap();
            assert!(response.contains("Hello /serve"));
        });
    }

    #[test]
    #[cfg(feature = "tls")]
    fn test_serve_tls() {
        use rustls::pki_types::{CertificateDer, PrivateKeyDer, PrivatePkcs8KeyDer, ServerName};
        use std::sync::Arc;

        // Self-signed certificate valid for the loopback address the client dials.
        let subject_alt_names = vec!["127.0.0.1".to_string(), "localhost".to_string()];
        let cert = rcgen::generate_simple_self_signed(subject_alt_names).unwrap();
        let cert_der = cert.cert.der().to_vec();
        let key_der = cert.signing_key.serialize_der();
        let certs = vec![CertificateDer::from(cert_der.clone())];
        let key = PrivateKeyDer::Pkcs8(PrivatePkcs8KeyDer::from(key_der));
        let server_config = rustls::ServerConfig::builder()
            .with_no_client_auth()
            .with_single_cert(certs, key)
            .unwrap();

        let addr = free_local_addr();
        std::thread::spawn(move || {
            serve_tls(addr, server_config, || MockServer);
        });
        futures_lite::future::block_on(async {
            // The client trusts exactly the server's self-signed certificate.
            let mut root_store = rustls::RootCertStore::empty();
            root_store.add(CertificateDer::from(cert_der)).unwrap();
            let client_config = rustls::ClientConfig::builder()
                .with_root_certificates(root_store)
                .with_no_client_auth();
            let connector = futures_rustls::TlsConnector::from(Arc::new(client_config));
            let stream = connect_with_retry(addr).await;
            let domain = ServerName::try_from("127.0.0.1").unwrap();
            let mut tls_stream = connector.connect(domain, stream).await.unwrap();
            tls_stream
                .write_all(b"GET /serve_tls HTTP/1.1\r\n\r\n")
                .await
                .unwrap();
            // Assumes the whole (tiny) response arrives in one read on loopback.
            let mut buf = vec![0u8; 1024];
            let n = tls_stream.read(&mut buf).await.unwrap();
            let response = std::str::from_utf8(&buf[..n]).unwrap();
            assert!(response.contains("Hello /serve_tls"));
        });
    }
}