use std::net::SocketAddr;
use std::sync::Arc;
use axum_server::tls_rustls::RustlsConfig;
use axum_server::Handle;
use derive_more::derive::Display;
use derive_more::Constructor;
use futures::future::BoxFuture;
use thiserror::Error;
use tokio::sync::oneshot::{Receiver, Sender};
use torrust_tracker_configuration::AccessTokens;
use tracing::{instrument, Level};
use super::routes::router;
use crate::bootstrap::jobs::Started;
use crate::core::Tracker;
use crate::servers::apis::API_LOG_TARGET;
use crate::servers::custom_axum_server::{self, TimeoutAcceptor};
use crate::servers::logging::STARTED_ON;
use crate::servers::registar::{ServiceHealthCheckJob, ServiceRegistration, ServiceRegistrationForm};
use crate::servers::signals::{graceful_shutdown, Halted};
#[derive(Debug, Error)]
pub enum Error {
#[error("Error when starting or stopping the API server")]
FailedToStartOrStop(String),
}
/// The tracker API server, whose lifecycle state is encoded in the type
/// parameter `S` (either [`Stopped`] or [`Running`]).
///
/// The derived `Display` has no format attribute on this single-field struct,
/// so it renders the current state.
#[allow(clippy::module_name_repetitions)]
#[derive(Debug, Display)]
pub struct ApiServer<S>
where
    S: std::fmt::Debug + std::fmt::Display,
{
    /// Current lifecycle state of the server.
    pub state: S,
}

/// An [`ApiServer`] in the stopped state.
#[allow(clippy::module_name_repetitions)]
pub type StoppedApiServer = ApiServer<Stopped>;

/// An [`ApiServer`] in the running state.
#[allow(clippy::module_name_repetitions)]
pub type RunningApiServer = ApiServer<Running>;
#[derive(Debug, Display)]
#[display("Stopped: {launcher}")]
pub struct Stopped {
launcher: Launcher,
}
#[derive(Debug, Display)]
#[display("Running (with local address): {local_addr}")]
pub struct Running {
pub local_addr: SocketAddr,
pub halt_task: tokio::sync::oneshot::Sender<Halted>,
pub task: tokio::task::JoinHandle<Launcher>,
}
impl Running {
#[must_use]
pub fn new(
local_addr: SocketAddr,
halt_task: tokio::sync::oneshot::Sender<Halted>,
task: tokio::task::JoinHandle<Launcher>,
) -> Self {
Self {
local_addr,
halt_task,
task,
}
}
}
impl ApiServer<Stopped> {
#[must_use]
pub fn new(launcher: Launcher) -> Self {
Self {
state: Stopped { launcher },
}
}
#[instrument(skip(self, tracker, form, access_tokens), err, ret(Display, level = Level::INFO))]
pub async fn start(
self,
tracker: Arc<Tracker>,
form: ServiceRegistrationForm,
access_tokens: Arc<AccessTokens>,
) -> Result<ApiServer<Running>, Error> {
let (tx_start, rx_start) = tokio::sync::oneshot::channel::<Started>();
let (tx_halt, rx_halt) = tokio::sync::oneshot::channel::<Halted>();
let launcher = self.state.launcher;
let task = tokio::spawn(async move {
tracing::debug!(target: API_LOG_TARGET, "Starting with launcher in spawned task ...");
let _task = launcher.start(tracker, access_tokens, tx_start, rx_halt).await;
tracing::debug!(target: API_LOG_TARGET, "Started with launcher in spawned task");
launcher
});
let api_server = match rx_start.await {
Ok(started) => {
form.send(ServiceRegistration::new(started.address, check_fn))
.expect("it should be able to send service registration");
ApiServer {
state: Running::new(started.address, tx_halt, task),
}
}
Err(err) => {
let msg = format!("Unable to start API server: {err}");
tracing::error!("{}", msg);
panic!("{}", msg);
}
};
Ok(api_server)
}
}
impl ApiServer<Running> {
#[instrument(skip(self), err, ret(Display, level = Level::INFO))]
pub async fn stop(self) -> Result<ApiServer<Stopped>, Error> {
self.state
.halt_task
.send(Halted::Normal)
.map_err(|_| Error::FailedToStartOrStop("Task killer channel was closed.".to_string()))?;
let launcher = self.state.task.await.map_err(|e| Error::FailedToStartOrStop(e.to_string()))?;
Ok(ApiServer {
state: Stopped { launcher },
})
}
}
/// Builds the health-check job for an API server listening on `binding`.
///
/// The job is spawned immediately. It issues a `GET` to
/// `http://{binding}/api/health_check` and resolves to the response status
/// rendered as a string, or to the request error rendered as a string.
///
/// NOTE(review): the URL always uses plain `http`, even for servers launched with
/// TLS. Confirm that this is intended for TLS-enabled deployments.
#[must_use]
#[instrument(skip())]
pub fn check_fn(binding: &SocketAddr) -> ServiceHealthCheckJob {
    let url = format!("http://{binding}/api/health_check");
    let info = format!("checking api health check at: {url}");

    let job = tokio::spawn(async move {
        reqwest::get(url)
            .await
            .map(|response| response.status().to_string())
            .map_err(|err| err.to_string())
    });

    ServiceHealthCheckJob::new(*binding, info, job)
}
/// Everything needed to bind and serve the tracker API: the socket address to
/// bind and an optional Rustls configuration (TLS is used when it is present).
#[derive(Constructor, Debug)]
pub struct Launcher {
    bind_to: SocketAddr,
    tls: Option<RustlsConfig>,
}

impl std::fmt::Display for Launcher {
    /// Renders as `(with socket): <addr>, using TLS` or `(with socket): <addr>, without TLS`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let tls_mode = match self.tls {
            Some(_) => "using TLS",
            None => "without TLS",
        };

        write!(f, "(with socket): {}, {}", self.bind_to, tls_mode)
    }
}
impl Launcher {
#[instrument(skip(self, tracker, access_tokens, tx_start, rx_halt))]
pub fn start(
&self,
tracker: Arc<Tracker>,
access_tokens: Arc<AccessTokens>,
tx_start: Sender<Started>,
rx_halt: Receiver<Halted>,
) -> BoxFuture<'static, ()> {
let router = router(tracker, access_tokens);
let socket = std::net::TcpListener::bind(self.bind_to).expect("Could not bind tcp_listener to address.");
let address = socket.local_addr().expect("Could not get local_addr from tcp_listener.");
let handle = Handle::new();
tokio::task::spawn(graceful_shutdown(
handle.clone(),
rx_halt,
format!("Shutting down tracker API server on socket address: {address}"),
));
let tls = self.tls.clone();
let protocol = if tls.is_some() { "https" } else { "http" };
tracing::info!(target: API_LOG_TARGET, "Starting on {protocol}://{}", address);
let running = Box::pin(async {
match tls {
Some(tls) => custom_axum_server::from_tcp_rustls_with_timeouts(socket, tls)
.handle(handle)
.serve(router.into_make_service_with_connect_info::<std::net::SocketAddr>())
.await
.expect("Axum server for tracker API crashed."),
None => custom_axum_server::from_tcp_with_timeouts(socket)
.handle(handle)
.acceptor(TimeoutAcceptor)
.serve(router.into_make_service_with_connect_info::<std::net::SocketAddr>())
.await
.expect("Axum server for tracker API crashed."),
}
});
tracing::info!(target: API_LOG_TARGET, "{STARTED_ON} {protocol}://{}", address);
tx_start
.send(Started { address })
.expect("the HTTP(s) Tracker API service should not be dropped");
running
}
}
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use torrust_tracker_test_helpers::configuration::ephemeral_public;

    use crate::bootstrap::app::initialize_with_configuration;
    use crate::bootstrap::jobs::make_rust_tls;
    use crate::servers::apis::server::{ApiServer, Launcher};
    use crate::servers::registar::Registar;

    /// Starts the API server using the configuration from `ephemeral_public()`,
    /// stops it again, and checks that the launcher handed back on stop still
    /// targets the original bind address.
    #[tokio::test]
    async fn it_should_be_able_to_start_and_stop() {
        let configuration = Arc::new(ephemeral_public());
        let api_config = &configuration.http_api.clone().unwrap();

        let tracker = initialize_with_configuration(&configuration);

        let bind_to = api_config.bind_address;

        let tls = make_rust_tls(&api_config.tsl_config)
            .await
            .map(|tls| tls.expect("tls config failed"));

        let access_tokens = Arc::new(api_config.access_tokens.clone());

        let registar = &Registar::default();

        let running = ApiServer::new(Launcher::new(bind_to, tls))
            .start(tracker, registar.give_form(), access_tokens)
            .await
            .expect("it should start the server");

        let stopped = running.stop().await.expect("it should stop the server");

        assert_eq!(stopped.state.launcher.bind_to, bind_to);
    }
}