essential_builder_api/
lib.rs1use axum::{
8 routing::{get, post},
9 Router,
10};
11use essential_builder_db as db;
12use std::{io, net::SocketAddr};
13use thiserror::Error;
14use tokio::{
15 net::{TcpListener, TcpStream},
16 task::JoinSet,
17};
18use tower_http::cors::CorsLayer;
19
20pub mod endpoint;
21
22#[derive(Clone)]
24pub struct State {
25 pub conn_pool: db::ConnectionPool,
27}
28
29#[derive(Debug, Error)]
31pub enum ServeNextConnError {
32 #[error("failed to acquire next connection: {0}")]
34 Next(#[from] io::Error),
35 #[error("{0}")]
37 Serve(#[from] ServeConnError),
38}
39
40#[derive(Debug, Error)]
42#[error("Serve connection error: {0}")]
43pub struct ServeConnError(#[from] Box<dyn std::error::Error + Send + Sync>);
44
/// Default value callers may pass as `conn_limit` to [`serve`]: the number of
/// connection tasks tracked at once before new accepts wait for one to finish.
pub const DEFAULT_CONNECTION_LIMIT: usize = 2000;
48
49pub async fn serve(router: &Router, listener: &TcpListener, conn_limit: usize) {
58 let mut conn_set = JoinSet::new();
59 loop {
60 serve_next_conn(router, listener, conn_limit, &mut conn_set).await;
61 }
62}
63
64#[tracing::instrument(skip_all)]
91pub async fn serve_next_conn(
92 router: &Router,
93 listener: &TcpListener,
94 conn_limit: usize,
95 conn_set: &mut JoinSet<()>,
96) {
97 let stream = match next_conn(listener, conn_limit, conn_set).await {
99 Ok((stream, _remote_addr)) => {
100 #[cfg(feature = "tracing")]
101 tracing::trace!("Accepted new connection from: {_remote_addr}");
102 stream
103 }
104 Err(_err) => {
105 #[cfg(feature = "tracing")]
106 tracing::trace!("Failed to accept connection {_err}");
107 return;
108 }
109 };
110
111 let router = router.clone();
113 conn_set.spawn(async move {
114 if let Err(_err) = serve_conn(&router, stream).await {
115 #[cfg(feature = "tracing")]
116 tracing::trace!("Serve connection error: {_err}");
117 }
118 });
119}
120
121#[tracing::instrument(skip_all, err)]
126pub async fn next_conn(
127 listener: &TcpListener,
128 conn_limit: usize,
129 conn_set: &mut JoinSet<()>,
130) -> io::Result<(TcpStream, SocketAddr)> {
131 if conn_set.len() >= conn_limit {
133 #[cfg(feature = "tracing")]
134 tracing::info!("Connection limit reached: {conn_limit}");
135 conn_set.join_next().await.expect("set cannot be empty")?;
136 }
137 tracing::trace!("Awaiting new connection at {}", listener.local_addr()?);
139 listener.accept().await
140}
141
142#[tracing::instrument(skip_all, err)]
144pub async fn serve_conn(router: &Router, stream: TcpStream) -> Result<(), ServeConnError> {
145 let stream = hyper_util::rt::TokioIo::new(stream);
148
149 let hyper_service = hyper::service::service_fn(
153 move |request: axum::extract::Request<hyper::body::Incoming>| {
154 tower::Service::call(&mut router.clone(), request)
155 },
156 );
157
158 let executor = hyper_util::rt::TokioExecutor::new();
160 let conn = hyper_util::server::conn::auto::Builder::new(executor).http2_only();
161 conn.serve_connection(stream, hyper_service)
162 .await
163 .map_err(ServeConnError)
164}
165
166pub fn router(state: State) -> Router {
169 with_endpoints(Router::new())
170 .layer(cors_layer())
171 .with_state(state)
172}
173
174pub fn with_endpoints(router: Router<State>) -> Router<State> {
176 use endpoint::*;
177 router
178 .route(health_check::PATH, get(health_check::handler))
179 .route(
180 latest_solution_set_failures::PATH,
181 get(latest_solution_set_failures::handler),
182 )
183 .route(
184 list_solution_set_failures::PATH,
185 get(list_solution_set_failures::handler),
186 )
187 .route(
188 submit_solution_set::PATH,
189 post(submit_solution_set::handler),
190 )
191}
192
193pub fn cors_layer() -> CorsLayer {
195 CorsLayer::new()
196 .allow_origin(tower_http::cors::Any)
197 .allow_methods([http::Method::GET, http::Method::OPTIONS, http::Method::POST])
198 .allow_headers([http::header::CONTENT_TYPE])
199}