use axum::{routing::get, Router};
use essential_node::db;
use essential_node_types::block_notify::BlockRx;
use std::{io, net::SocketAddr};
use thiserror::Error;
use tokio::{
net::{TcpListener, TcpStream},
task::JoinSet,
};
use tower_http::cors::CorsLayer;
/// Endpoint modules routed by [`with_endpoints`]; each one referenced there
/// exposes a `PATH` and a `handler`.
pub mod endpoint;
/// State attached to the [`Router`] by [`router`] via `with_state`.
///
/// Derives `Clone`, so the router can hold and clone it.
#[derive(Clone)]
pub struct State {
/// Connection pool from `essential_node::db`.
pub conn_pool: db::ConnectionPool,
/// Optional block-notification receiver from `essential_node_types::block_notify`.
// NOTE(review): presumably consumed by the `subscribe_blocks` endpoint and `None`
// when block notifications are unavailable — confirm against the endpoint module.
pub new_block: Option<BlockRx>,
}
/// Error covering both stages of handling one connection: acquiring it and serving it.
// NOTE(review): nothing in this part of the file returns this type
// (`serve_next_conn` logs and drops its errors) — confirm it is still used elsewhere.
#[derive(Debug, Error)]
pub enum ServeNextConnError {
/// An I/O error occurred while acquiring the next connection.
#[error("failed to acquire next connection: {0}")]
Next(#[from] io::Error),
/// Serving the connection failed; displays the inner [`ServeConnError`] as-is.
#[error("{0}")]
Serve(#[from] ServeConnError),
}
/// Error returned by [`serve_conn`], wrapping the boxed error produced while
/// serving a single connection.
#[derive(Debug, Error)]
#[error("Serve connection error: {0}")]
pub struct ServeConnError(#[from] Box<dyn std::error::Error + Send + Sync>);
/// Default connection limit (2,000).
// NOTE(review): presumably intended as the `conn_limit` argument of `serve`;
// it is not referenced in this part of the file — confirm at the call site.
pub const DEFAULT_CONNECTION_LIMIT: usize = 2_000;
pub async fn serve(router: &Router, listener: &TcpListener, conn_limit: usize) {
let mut conn_set = JoinSet::new();
loop {
serve_next_conn(router, listener, conn_limit, &mut conn_set).await;
}
}
/// Accept the next connection and spawn a task on `conn_set` that serves it.
///
/// The connection is acquired through [`next_conn`], which applies `conn_limit`.
/// On success the `router` is cloned and a task running [`serve_conn`] is
/// spawned onto `conn_set`.
///
/// Errors are never returned to the caller: a failure to accept, or a failure
/// while serving, is logged at trace level when the `tracing` feature is
/// enabled and is otherwise dropped.
// The instrumentation is feature-gated like every log call in the body, so the
// function does not reference the `tracing` crate when the feature is disabled.
#[cfg_attr(feature = "tracing", tracing::instrument(skip_all))]
pub async fn serve_next_conn(
    router: &Router,
    listener: &TcpListener,
    conn_limit: usize,
    conn_set: &mut JoinSet<()>,
) {
    // Bindings are `_`-prefixed because they are only read by the gated log calls.
    let stream = match next_conn(listener, conn_limit, conn_set).await {
        Ok((stream, _remote_addr)) => {
            #[cfg(feature = "tracing")]
            tracing::trace!("Accepted new connection from: {_remote_addr}");
            stream
        }
        Err(_err) => {
            #[cfg(feature = "tracing")]
            tracing::trace!("Failed to accept connection {_err}");
            return;
        }
    };

    // The spawned task must own its router, so clone it before moving it in.
    let router = router.clone();
    conn_set.spawn(async move {
        if let Err(_err) = serve_conn(&router, stream).await {
            #[cfg(feature = "tracing")]
            tracing::trace!("Serve connection error: {_err}");
        }
    });
}
#[tracing::instrument(skip_all, err)]
pub async fn next_conn(
listener: &TcpListener,
conn_limit: usize,
conn_set: &mut JoinSet<()>,
) -> io::Result<(TcpStream, SocketAddr)> {
if conn_set.len() >= conn_limit {
#[cfg(feature = "tracing")]
tracing::info!("Connection limit reached: {conn_limit}");
conn_set.join_next().await.expect("set cannot be empty")?;
}
tracing::trace!("Awaiting new connection at {}", listener.local_addr()?);
listener.accept().await
}
/// Serve a single accepted TCP connection with the given `router`.
///
/// The stream is wrapped for hyper, each incoming request is dispatched to a
/// clone of `router`, and the connection is driven to completion by a
/// hyper-util connection builder configured with `http2_only()`.
///
/// # Errors
///
/// Returns a [`ServeConnError`] wrapping whatever error `serve_connection`
/// reports.
// Feature-gated to match the gated log calls elsewhere in this file, so the
// function does not reference the `tracing` crate when the feature is disabled.
#[cfg_attr(feature = "tracing", tracing::instrument(skip_all, err))]
pub async fn serve_conn(router: &Router, stream: TcpStream) -> Result<(), ServeConnError> {
    // Adapt the tokio stream to the I/O traits hyper expects.
    let stream = hyper_util::rt::TokioIo::new(stream);

    // `Service::call` needs `&mut self`, so each request runs on its own clone of the router.
    let hyper_service = hyper::service::service_fn(
        move |request: axum::extract::Request<hyper::body::Incoming>| {
            tower::Service::call(&mut router.clone(), request)
        },
    );

    let executor = hyper_util::rt::TokioExecutor::new();
    let conn = hyper_util::server::conn::auto::Builder::new(executor).http2_only();
    conn.serve_connection(stream, hyper_service)
        .await
        .map_err(ServeConnError)
}
/// Build the complete [`Router`] for the given `state`.
///
/// Starts from an empty router, registers the routes from [`with_endpoints`],
/// wraps them in the layer returned by [`cors_layer`], and finally attaches
/// `state`.
pub fn router(state: State) -> Router {
    let routes = with_endpoints(Router::new());
    let routes_with_cors = routes.layer(cors_layer());
    routes_with_cors.with_state(state)
}
pub fn with_endpoints(router: Router<State>) -> Router<State> {
use endpoint::*;
router
.route(health_check::PATH, get(health_check::handler))
.route(list_blocks::PATH, get(list_blocks::handler))
.route(query_state::PATH, get(query_state::handler))
.route(subscribe_blocks::PATH, get(subscribe_blocks::handler))
}
/// Build the CORS layer applied by [`router`].
///
/// Allows any origin, the `GET` and `OPTIONS` methods, and the `Content-Type`
/// request header.
pub fn cors_layer() -> CorsLayer {
    let methods = [http::Method::GET, http::Method::OPTIONS];
    let headers = [http::header::CONTENT_TYPE];

    CorsLayer::new()
        .allow_origin(tower_http::cors::Any)
        .allow_methods(methods)
        .allow_headers(headers)
}