use std::net::SocketAddr;
use std::sync::Arc;
use axum_server::tls_rustls::RustlsConfig;
use axum_server::Handle;
use derive_more::Constructor;
use futures::future::BoxFuture;
use tokio::sync::oneshot::{Receiver, Sender};
use tracing::instrument;
use super::v1::routes::router;
use crate::bootstrap::jobs::Started;
use crate::core::Tracker;
use crate::servers::custom_axum_server::{self, TimeoutAcceptor};
use crate::servers::http::HTTP_TRACKER_LOG_TARGET;
use crate::servers::logging::STARTED_ON;
use crate::servers::registar::{ServiceHealthCheckJob, ServiceRegistration, ServiceRegistrationForm};
use crate::servers::signals::{graceful_shutdown, Halted};
/// Error type returned by the HTTP tracker server lifecycle methods
/// (`HttpServer::<Stopped>::start` and `HttpServer::<Running>::stop`).
#[derive(Debug)]
pub enum Error {
    /// A failure described by a human-readable message.
    Error(String),
}
/// Everything needed to launch the HTTP tracker: the socket address to bind
/// and an optional TLS configuration.
///
/// `Launcher::new(bind_to, tls)` is generated by `derive_more::Constructor`
/// from the field order below, so the fields must not be reordered.
#[derive(Constructor, Debug)]
pub struct Launcher {
    /// Address the TCP listener is bound to.
    pub bind_to: SocketAddr,
    /// When `Some`, the server is served through rustls ("https");
    /// when `None`, it is served as plain "http".
    pub tls: Option<RustlsConfig>,
}
impl Launcher {
    /// Binds the listening socket and builds the future that runs the Axum server.
    ///
    /// The TCP listener is bound synchronously. The address reported to the
    /// caller is the listener's actual local address, read back with
    /// `local_addr`, not the configured `bind_to`.
    ///
    /// Before returning, this function:
    /// - spawns a `graceful_shutdown` task driven by `rx_halt`,
    /// - logs the start-up, and
    /// - sends the bound address through `tx_start`.
    ///
    /// The returned future is lazy: requests are only served once the caller
    /// polls it.
    ///
    /// # Panics
    ///
    /// Panics if the socket cannot be bound, if its local address cannot be
    /// read, or if the receiver of `tx_start` has been dropped. The returned
    /// future panics if the Axum server exits with an error.
    #[instrument(skip(self, tracker, tx_start, rx_halt))]
    fn start(&self, tracker: Arc<Tracker>, tx_start: Sender<Started>, rx_halt: Receiver<Halted>) -> BoxFuture<'static, ()> {
        let listener = std::net::TcpListener::bind(self.bind_to).expect("Could not bind tcp_listener to address.");
        let address = listener.local_addr().expect("Could not get local_addr from tcp_listener.");

        // The shutdown watcher shares the Axum handle so it can stop the server when halted.
        let server_handle = Handle::new();
        let shutdown_message = format!("Shutting down HTTP server on socket address: {address}");
        tokio::task::spawn(graceful_shutdown(server_handle.clone(), rx_halt, shutdown_message));

        let tls = self.tls.clone();
        let protocol = tls.as_ref().map_or("http", |_| "https");

        tracing::info!(target: HTTP_TRACKER_LOG_TARGET, "Starting on: {protocol}://{}", address);

        let app = router(tracker, address);

        let server: BoxFuture<'static, ()> = Box::pin(async move {
            match tls {
                // NOTE(review): unlike the plain-HTTP branch, no `TimeoutAcceptor` is installed
                // for TLS; presumably intentional (acceptor/rustls interaction?) — confirm before changing.
                Some(tls_config) => custom_axum_server::from_tcp_rustls_with_timeouts(listener, tls_config)
                    .handle(server_handle)
                    .serve(app.into_make_service_with_connect_info::<std::net::SocketAddr>())
                    .await
                    .expect("Axum server crashed."),
                None => custom_axum_server::from_tcp_with_timeouts(listener)
                    .handle(server_handle)
                    .acceptor(TimeoutAcceptor)
                    .serve(app.into_make_service_with_connect_info::<std::net::SocketAddr>())
                    .await
                    .expect("Axum server crashed."),
            }
        });

        tracing::info!(target: HTTP_TRACKER_LOG_TARGET, "{STARTED_ON}: {protocol}://{}", address);

        // The listener is already bound, so the address can be reported before the future is polled.
        tx_start
            .send(Started { address })
            .expect("the HTTP(s) Tracker service should not be dropped");

        server
    }
}
/// An [`HttpServer`] that is not running and can be started.
// The alias names repeat the `server` module name; the lint is silenced deliberately.
#[allow(clippy::module_name_repetitions)]
pub type StoppedHttpServer = HttpServer<Stopped>;

/// An [`HttpServer`] that is currently serving requests and can be stopped.
// The alias names repeat the `server` module name; the lint is silenced deliberately.
#[allow(clippy::module_name_repetitions)]
pub type RunningHttpServer = HttpServer<Running>;
/// HTTP(S) tracker server parameterised by its lifecycle state `S`.
///
/// `HttpServer<Stopped>` exposes `start`, which yields an
/// `HttpServer<Running>`; that in turn exposes `stop`, which yields an
/// `HttpServer<Stopped>` again.
// The type name repeats the `server` module name; the lint is silenced deliberately.
#[allow(clippy::module_name_repetitions)]
pub struct HttpServer<S> {
    /// Data owned by the current state.
    pub state: S,
}
/// State of a server that is not running.
///
/// It keeps the [`Launcher`] so the server can be started again.
pub struct Stopped {
    /// Launcher used by the next call to `start`.
    launcher: Launcher,
}
pub struct Running {
pub binding: SocketAddr,
pub halt_task: tokio::sync::oneshot::Sender<Halted>,
pub task: tokio::task::JoinHandle<Launcher>,
}
impl HttpServer<Stopped> {
#[must_use]
pub fn new(launcher: Launcher) -> Self {
Self {
state: Stopped { launcher },
}
}
pub async fn start(self, tracker: Arc<Tracker>, form: ServiceRegistrationForm) -> Result<HttpServer<Running>, Error> {
let (tx_start, rx_start) = tokio::sync::oneshot::channel::<Started>();
let (tx_halt, rx_halt) = tokio::sync::oneshot::channel::<Halted>();
let launcher = self.state.launcher;
let task = tokio::spawn(async move {
let server = launcher.start(tracker, tx_start, rx_halt);
server.await;
launcher
});
let binding = rx_start.await.expect("it should be able to start the service").address;
form.send(ServiceRegistration::new(binding, check_fn))
.expect("it should be able to send service registration");
Ok(HttpServer {
state: Running {
binding,
halt_task: tx_halt,
task,
},
})
}
}
impl HttpServer<Running> {
pub async fn stop(self) -> Result<HttpServer<Stopped>, Error> {
self.state
.halt_task
.send(Halted::Normal)
.map_err(|_| Error::Error("Task killer channel was closed.".to_string()))?;
let launcher = self.state.task.await.map_err(|e| Error::Error(e.to_string()))?;
Ok(HttpServer {
state: Stopped { launcher },
})
}
}
/// Builds the health-check job for an HTTP tracker bound at `binding`.
///
/// Spawns a task that issues `GET http://{binding}/health_check`. The task
/// yields the response status as a string on success, or the request error
/// message as a string on failure.
///
/// NOTE(review): the URL always uses `http://`, even for a TLS-enabled launcher;
/// presumably such a check fails — confirm whether TLS trackers need a different check.
#[must_use]
pub fn check_fn(binding: &SocketAddr) -> ServiceHealthCheckJob {
    let url = format!("http://{binding}/health_check");
    let info = format!("checking http tracker health check at: {url}");

    let job = tokio::spawn(async move {
        reqwest::get(url)
            .await
            .map(|response| response.status().to_string())
            .map_err(|err| err.to_string())
    });

    ServiceHealthCheckJob::new(*binding, info, job)
}
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use torrust_tracker_test_helpers::configuration::ephemeral_public;

    use crate::bootstrap::app::initialize_with_configuration;
    use crate::bootstrap::jobs::make_rust_tls;
    use crate::servers::http::server::{HttpServer, Launcher};
    use crate::servers::registar::Registar;

    /// Starting and then stopping the first configured HTTP tracker should give
    /// back a stopped server whose launcher still targets the configured address.
    #[tokio::test]
    async fn it_should_be_able_to_start_and_stop() {
        let configuration = Arc::new(ephemeral_public());
        let tracker = initialize_with_configuration(&configuration);

        let http_trackers = configuration
            .http_trackers
            .clone()
            .expect("missing HTTP trackers configuration");
        let http_tracker_config = &http_trackers[0];
        let bind_to = http_tracker_config.bind_address;

        let tls = make_rust_tls(&http_tracker_config.tsl_config)
            .await
            .map(|tls| tls.expect("tls config failed"));

        let registar = &Registar::default();

        let started = HttpServer::new(Launcher::new(bind_to, tls))
            .start(tracker, registar.give_form())
            .await
            .expect("it should start the server");

        let stopped = started.stop().await.expect("it should stop the server");

        assert_eq!(stopped.state.launcher.bind_to, bind_to);
    }
}