use apimock_config::Config;
use apimock_routing::ParsedRequest;
use console::style;
use http_body_util::{BodyExt, Empty};
use hyper::{
HeaderMap, Response, body,
header::{CONTENT_LENGTH, HeaderValue},
service::service_fn,
};
use hyper_util::{
rt::{TokioExecutor, TokioIo},
server::conn::auto::Builder,
};
use rustls::ServerConfig;
use tokio::net::TcpListener;
use tokio::sync::Mutex;
use tokio_rustls::TlsAcceptor;
use std::net::{SocketAddr, ToSocketAddrs};
use std::sync::Arc;
use crate::{
dyn_route::dyn_route_content,
error::{ServerError, ServerResult},
middleware::LoadedMiddlewares,
parsed_request::{capture_in_log, parsed_request_from},
respond_response::respond_response,
response::error_response::internal_server_error_response,
response_handler::default_response_headers,
tls::{load_certs, load_private_key},
types::BoxBody,
};
pub use crate::control::{ReloadHint, ServerControl, ServerHandle, ServerState};
/// State handed to every connection handler: the loaded configuration and the compiled
/// middlewares. Each listener wraps its own copy in `Arc<Mutex<_>>`, and `service` clones
/// it out of the mutex once per request.
#[derive(Clone)]
pub struct AppState {
    /// Loaded configuration (listeners, rule sets, fallback directory, logging options).
    pub config: Config,
    /// Middlewares compiled from `config.service.middlewares_file_paths`; consulted before
    /// rule sets on every request.
    pub middlewares: LoadedMiddlewares,
}
/// The mock API server: application state plus the listener addresses resolved from the
/// configuration.
pub struct Server {
    /// State cloned into each listener when it starts.
    pub app_state: AppState,
    /// Plain-HTTP listen address; `None` disables the HTTP listener.
    pub http_addr: Option<SocketAddr>,
    /// HTTPS listen address; `None` disables the TLS listener.
    pub https_addr: Option<SocketAddr>,
}
impl Server {
pub async fn new(config: Config) -> ServerResult<Self> {
let http_addr = resolve_listener(config.listener_http_addr().as_deref())?;
let https_addr = resolve_listener(config.listener_https_addr().as_deref())?;
let relative_dir_path = config
.current_dir_to_parent_dir_relative_path()
.map_err(ServerError::Config)?;
let middlewares = LoadedMiddlewares::compile(
config.service.middlewares_file_paths.as_deref().unwrap_or(&[]),
relative_dir_path.as_str(),
)?;
if !middlewares.is_empty() {
log::info!("middleware is activated: {} file(s)", middlewares.len());
}
Ok(Server {
http_addr,
https_addr,
app_state: AppState {
config,
middlewares,
},
})
}
pub async fn start(&self) {
let http = self.http_start();
let https = self.https_start();
tokio::join!(http, https);
}
async fn http_start(&self) {
let Some(addr) = self.http_addr else {
return;
};
let listener = match TcpListener::bind(addr).await {
Ok(l) => l,
Err(err) => {
log::error!("failed to bind HTTP listener at {}: {}", addr, err);
return;
}
};
log::info!(
"Greetings from apimock-rs (API Mock) !!\nListening on {} ...\n",
style(format!("http://{}", addr)).cyan()
);
let app_state = Arc::new(Mutex::new(self.app_state.clone()));
loop {
let (stream, _) = match listener.accept().await {
Ok(pair) => pair,
Err(err) => {
log::error!("HTTP accept failed: {}", err);
continue;
}
};
let io = TokioIo::new(stream);
let app_state = app_state.clone();
tokio::task::spawn(async move {
if let Err(err) = Builder::new(TokioExecutor::new())
.serve_connection(
io,
service_fn(move |request: hyper::Request<body::Incoming>| {
service(request, app_state.clone())
}),
)
.await
{
log::error!("{} to build connection: {:?}", style("failed").red(), err);
}
});
}
}
async fn https_start(&self) {
let Some(addr) = self.https_addr else {
return;
};
let tls = match self
.app_state
.config
.listener
.as_ref()
.and_then(|l| l.tls.as_ref())
{
Some(t) => t.clone(),
None => {
log::error!("internal: HTTPS listener scheduled without TLS config");
return;
}
};
let certs = match load_certs(tls.cert.as_str()) {
Ok(c) => c,
Err(err) => {
log::error!("{}", err);
return;
}
};
let key = match load_private_key(tls.key.as_str()) {
Ok(k) => k,
Err(err) => {
log::error!("{}", err);
return;
}
};
let mut config = match ServerConfig::builder()
.with_no_client_auth()
.with_single_cert(certs, key)
{
Ok(c) => c,
Err(err) => {
log::error!("failed to build rustls ServerConfig: {}", err);
return;
}
};
config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()];
let acceptor = TlsAcceptor::from(Arc::new(config));
let listener = match TcpListener::bind(addr).await {
Ok(l) => l,
Err(err) => {
log::error!("failed to bind HTTPS listener at {}: {}", addr, err);
return;
}
};
log::info!(
"Greetings from apimock-rs (API Mock) !!\nListening on {} ...\n",
style(format!("https://{}", addr)).cyan()
);
let app_state = Arc::new(Mutex::new(self.app_state.clone()));
loop {
let (stream, _) = match listener.accept().await {
Ok(pair) => pair,
Err(err) => {
log::error!("HTTPS accept failed: {}", err);
continue;
}
};
let acceptor = acceptor.clone();
let app_state = app_state.clone();
tokio::spawn(async move {
let tls_stream = match acceptor.accept(stream).await {
Ok(s) => s,
Err(e) => {
log::error!("TLS handshake failed: {:?}", e);
return;
}
};
let io = TokioIo::new(tls_stream);
let app_state = app_state.clone();
tokio::task::spawn(async move {
if let Err(err) = Builder::new(TokioExecutor::new())
.serve_connection(
io,
service_fn(move |request: hyper::Request<body::Incoming>| {
service(request, app_state.clone())
}),
)
.await
{
log::error!("{} to build connection: {:?}", style("failed").red(), err);
}
});
});
}
}
}
/// Resolves an optional `host:port` listener string to its first socket address.
///
/// Returns `Ok(None)` when no address is configured.
///
/// # Errors
///
/// Returns `ServerError::ListenerAddress` in two cases: resolution fails, or resolution
/// yields no addresses.
fn resolve_listener(addr_str: Option<&str>) -> ServerResult<Option<SocketAddr>> {
    let addr_str = match addr_str {
        Some(s) => s,
        None => return Ok(None),
    };
    let listener_error = |reason: String| ServerError::ListenerAddress {
        addr: addr_str.to_owned(),
        reason,
    };
    match addr_str.to_socket_addrs() {
        Ok(mut resolved) => match resolved.next() {
            Some(first) => Ok(Some(first)),
            None => Err(listener_error(
                "address resolved to no socket addresses".to_owned(),
            )),
        },
        Err(e) => Err(listener_error(e.to_string())),
    }
}
pub async fn service(
request: hyper::Request<body::Incoming>,
app_state: Arc<Mutex<AppState>>,
) -> Result<hyper::Response<BoxBody>, hyper::http::Error> {
let request_headers = request.headers().clone();
if request.method() == hyper::Method::OPTIONS {
return handle_options(&request_headers);
}
let parsed_request = match parsed_request_from(request).await {
Ok(x) => x,
Err(err) => return internal_server_error_response(err.as_str(), &request_headers),
};
let shared_app_state = { app_state.lock().await.clone() };
let config = shared_app_state.config;
let middlewares = shared_app_state.middlewares;
capture_in_log(&parsed_request, config.log.clone().unwrap_or_default().verbose);
if let Some(response) = middleware_response(&middlewares, &parsed_request).await {
return response;
}
if let Some(response) = rule_set_response(&config, &parsed_request).await {
return response;
}
dyn_route_content(
parsed_request.url_path.as_str(),
config.service.fallback_respond_dir.as_str(),
&request_headers,
)
.await
}
/// Offers the request to each compiled middleware in order.
///
/// Returns the first `Some` a handler produces, or `None` when no middleware handled the
/// request.
async fn middleware_response(
    middlewares: &LoadedMiddlewares,
    parsed_request: &ParsedRequest,
) -> Option<Result<hyper::Response<BoxBody>, hyper::http::Error>> {
    for handler in middlewares.iter() {
        let outcome = handler
            .handle(
                parsed_request.url_path.as_str(),
                parsed_request.body_json.as_ref(),
                &parsed_request.component_parts.headers,
            )
            .await;
        if outcome.is_some() {
            return outcome;
        }
    }
    None
}
/// Finds the first rule set (in configuration order) that matches the request.
///
/// The matched respond definition is rendered via `respond_response`, using that rule
/// set's `dir_prefix`. Returns `None` when no rule set matches.
async fn rule_set_response(
    config: &Config,
    parsed_request: &ParsedRequest,
) -> Option<Result<hyper::Response<BoxBody>, hyper::http::Error>> {
    for (idx, candidate) in config.service.rule_sets.iter().enumerate() {
        let Some(matched) =
            candidate.find_matched(parsed_request, config.service.strategy.as_ref(), idx)
        else {
            continue;
        };
        let prefix = candidate.dir_prefix();
        let rendered = respond_response(&matched, prefix.as_str(), parsed_request).await;
        return Some(rendered);
    }
    None
}
/// Answers an `OPTIONS` request without consulting middlewares or rule sets.
///
/// The response is `204 No Content` with `Content-Length: 0`. It also carries every header
/// produced by `default_response_headers` for the incoming request headers, including all
/// values of multi-valued headers.
fn handle_options(
    request_headers: &HeaderMap,
) -> Result<hyper::Response<BoxBody>, hyper::http::Error> {
    let mut response = Response::new(Empty::new().boxed());
    *response.status_mut() = hyper::StatusCode::NO_CONTENT;
    response
        .headers_mut()
        .insert(CONTENT_LENGTH, HeaderValue::from_static("0"));
    // Iterating a `HeaderMap` yields `(None, value)` for the second and later values of a
    // multi-valued header, so skipping `None` keys used to drop them. `Extend` understands
    // that layout: the first value of each name replaces any existing entry (as `insert`
    // did), and continuation values are appended.
    response
        .headers_mut()
        .extend(default_response_headers(request_headers));
    Ok(response)
}