use std::collections::HashSet;
use std::convert::Infallible;
use std::net::SocketAddr;
use std::sync::Arc;
use parking_lot::RwLock;
use http_body_util::{BodyExt, Full};
use hyper::body::{Bytes, Incoming};
#[cfg(feature = "https")]
use hyper::header::{ALT_SVC, HeaderValue};
use hyper::header::{AUTHORIZATION, CONTENT_TYPE, WWW_AUTHENTICATE};
use hyper::server::conn::http1;
use hyper::service::service_fn;
use hyper::{Method, Request, Response, StatusCode};
use hyper_util::rt::TokioIo;
use mechanics_core::job::MechanicsJob;
use tokio::net::TcpListener;
#[cfg(feature = "https")]
use tower::service_fn as tower_service_fn;
use mechanics_core::MechanicsPool;
#[cfg(feature = "https")]
mod tls;
#[cfg(feature = "https")]
pub use tls::TlsConfig;
#[cfg(feature = "https")]
pub use mechanics_http_server::Http3ServerConfig;
pub use mechanics_core::MechanicsPoolConfig;
type HttpResponse = Response<Full<Bytes>>;
enum ApiError {
NotFound,
Unauthorized,
InvalidType,
InvalidRequest,
Pool(String),
Internal,
}
impl ApiError {
fn to_response(&self) -> HttpResponse {
let (status, message) = match self {
Self::NotFound => (StatusCode::NOT_FOUND, "Not found".to_string()),
Self::Unauthorized => (StatusCode::UNAUTHORIZED, "Unauthorized".to_string()),
Self::InvalidType => (StatusCode::BAD_REQUEST, "Invalid type".to_string()),
Self::InvalidRequest => (StatusCode::BAD_REQUEST, "Invalid request".to_string()),
Self::Pool(err) => (StatusCode::BAD_REQUEST, err.clone()),
Self::Internal => (
StatusCode::INTERNAL_SERVER_ERROR,
"Internal server error".to_string(),
),
};
let mut response = json_response(status, &serde_json::json!({ "error": message }));
if matches!(self, Self::Unauthorized) {
response.headers_mut().insert(
WWW_AUTHENTICATE,
hyper::header::HeaderValue::from_static("Bearer"),
);
}
response
}
#[cfg(feature = "https")]
fn to_h3_response(&self) -> Response<Full<Bytes>> {
let (status, message) = match self {
Self::NotFound => (StatusCode::NOT_FOUND, "Not found".to_string()),
Self::Unauthorized => (StatusCode::UNAUTHORIZED, "Unauthorized".to_string()),
Self::InvalidType => (StatusCode::BAD_REQUEST, "Invalid type".to_string()),
Self::InvalidRequest => (StatusCode::BAD_REQUEST, "Invalid request".to_string()),
Self::Pool(err) => (StatusCode::BAD_REQUEST, err.clone()),
Self::Internal => (
StatusCode::INTERNAL_SERVER_ERROR,
"Internal server error".to_string(),
),
};
let mut response = json_response_bytes(status, &serde_json::json!({ "error": message }));
if matches!(self, Self::Unauthorized) {
response.headers_mut().insert(
WWW_AUTHENTICATE,
hyper::header::HeaderValue::from_static("Bearer"),
);
}
response
}
}
#[cfg(feature = "https")]
/// Marks `response` with a `Strict-Transport-Security` header (`max-age` of two years),
/// replacing any existing value, and hands it back.
fn with_hsts(response: HttpResponse) -> HttpResponse {
    let mut response = response;
    let policy = hyper::header::HeaderValue::from_static("max-age=63072000");
    response
        .headers_mut()
        .insert(hyper::header::STRICT_TRANSPORT_SECURITY, policy);
    response
}
/// Builds a response with the given `status` whose body is `value` encoded as JSON.
///
/// Sets `Content-Type: application/json`. If encoding fails the body falls back to `{}`.
fn json_response(status: StatusCode, value: &serde_json::Value) -> HttpResponse {
    let payload = match serde_json::to_vec(value) {
        Ok(encoded) => Bytes::from(encoded),
        Err(_) => Bytes::from_static(b"{}"),
    };
    let mut response = Response::new(Full::new(payload));
    *response.status_mut() = status;
    let json_type = hyper::header::HeaderValue::from_static("application/json");
    response.headers_mut().insert(CONTENT_TYPE, json_type);
    response
}
#[cfg(feature = "https")]
/// JSON response builder used by the HTTP/3 path.
///
/// `Response<Full<Bytes>>` is exactly [`HttpResponse`], so this shares [`json_response`]
/// instead of duplicating it. The status, body, and `Content-Type` are identical.
fn json_response_bytes(status: StatusCode, value: &serde_json::Value) -> Response<Full<Bytes>> {
    json_response(status, value)
}
/// Reports whether the request's `Content-Type` media type is `application/json`.
///
/// Parameters after `;` (e.g. `charset=utf-8`) are ignored, and the comparison is
/// ASCII case-insensitive. A missing or non-visible-ASCII header counts as "not JSON".
fn has_json_content_type(req: &Request<Incoming>) -> bool {
    let Some(raw) = req.headers().get(CONTENT_TYPE) else {
        return false;
    };
    let Ok(text) = raw.to_str() else {
        return false;
    };
    let media_type = text.split(';').next().unwrap_or_default();
    media_type.trim().eq_ignore_ascii_case("application/json")
}
/// Extracts the token from an `Authorization` header value of the form `Bearer <token>`.
///
/// The scheme is matched case-insensitively, and any amount of surrounding or separating
/// whitespace is accepted. Returns `None` for another scheme, a missing token, or any
/// extra word after the token.
fn parse_bearer_token(header_value: &str) -> Option<&str> {
    let mut words = header_value.split_whitespace();
    match (words.next(), words.next(), words.next()) {
        (Some(scheme), Some(token), None)
            if scheme.eq_ignore_ascii_case("bearer") && !token.is_empty() =>
        {
            Some(token)
        }
        _ => None,
    }
}
fn bearer_token(req: &Request<Incoming>) -> Option<&str> {
req.headers()
.get(AUTHORIZATION)
.and_then(|value| value.to_str().ok())
.and_then(parse_bearer_token)
}
#[cfg(feature = "https")]
/// Header-map variant of [`bearer_token`], for requests whose body type is not `Incoming`.
fn bearer_token_from_headers(headers: &hyper::HeaderMap) -> Option<&str> {
    let header = headers.get(AUTHORIZATION)?;
    let text = header.to_str().ok()?;
    parse_bearer_token(text)
}
/// True when the request presents a bearer token contained in `tokens`.
///
/// The read lock is taken only if a well-formed token is present.
fn is_authorized(tokens: &RwLock<HashSet<String>>, req: &Request<Incoming>) -> bool {
    bearer_token(req).is_some_and(|token| tokens.read().contains(token))
}
#[cfg(feature = "https")]
/// Header-map variant of [`is_authorized`].
///
/// True when `headers` carries a bearer token contained in `tokens`.
fn is_authorized_headers(tokens: &RwLock<HashSet<String>>, headers: &hyper::HeaderMap) -> bool {
    bearer_token_from_headers(headers).is_some_and(|token| tokens.read().contains(token))
}
/// Decodes the request body into a [`MechanicsJob`].
///
/// Fails with [`ApiError::InvalidType`] unless the media type is `application/json`.
/// Fails with [`ApiError::InvalidRequest`] when the body cannot be read or does not
/// deserialize.
///
/// NOTE(review): the whole body is buffered with no size cap. In `handle_request` this runs
/// only after the token check, but consider `http_body_util::Limited` if job payloads
/// should be bounded.
async fn parse_json_job(req: Request<Incoming>) -> Result<MechanicsJob, ApiError> {
    if !has_json_content_type(&req) {
        return Err(ApiError::InvalidType);
    }
    let collected = match req.into_body().collect().await {
        Ok(collected) => collected,
        Err(_) => return Err(ApiError::InvalidRequest),
    };
    let payload = collected.to_bytes();
    serde_json::from_slice::<MechanicsJob>(&payload).map_err(|_| ApiError::InvalidRequest)
}
/// Runs `job` on the pool via Tokio's blocking thread pool, keeping async workers free.
///
/// A pool error becomes [`ApiError::Pool`] with its `Display` text. A failure to join
/// the blocking task becomes [`ApiError::Internal`].
async fn execute_job(
    pool: Arc<MechanicsPool>,
    job: MechanicsJob,
) -> Result<serde_json::Value, ApiError> {
    match tokio::task::spawn_blocking(move || pool.run(job)).await {
        Ok(Ok(value)) => Ok(value),
        Ok(Err(pool_error)) => Err(ApiError::Pool(pool_error.to_string())),
        Err(_join_error) => Err(ApiError::Internal),
    }
}
/// Routes and serves one request arriving on a TCP listener.
///
/// Only `POST /api/v1/mechanics` is served; everything else is `404`. The bearer token is
/// checked before the body is read. The JSON body is then decoded into a job and run on
/// the pool. Every outcome, including errors, becomes a JSON response, so the `Err` side
/// is uninhabited.
async fn handle_request(
    pool: Arc<MechanicsPool>,
    tokens: Arc<RwLock<HashSet<String>>>,
    req: Request<Incoming>,
) -> Result<HttpResponse, Infallible> {
    let targets_job_endpoint =
        req.method() == Method::POST && req.uri().path() == "/api/v1/mechanics";
    let outcome = if !targets_job_endpoint {
        Err(ApiError::NotFound)
    } else if !is_authorized(&tokens, &req) {
        Err(ApiError::Unauthorized)
    } else {
        match parse_json_job(req).await {
            Ok(job) => execute_job(pool, job).await,
            Err(error) => Err(error),
        }
    };
    let response = match outcome {
        Ok(result) => json_response(StatusCode::OK, &result),
        Err(error) => error.to_response(),
    };
    Ok(response)
}
#[cfg(feature = "https")]
/// Serves one request arriving over HTTP/3.
///
/// Routing and bearer-token checks match the TCP handler. Job execution is not wired up
/// on this path: the handler has no pool and does not read the body. Previously a valid,
/// authorized request was answered with `400 Invalid request`, which wrongly blamed the
/// client. It now gets `421 Misdirected Request`. RFC 9110 lets clients retry that status
/// on a different connection, i.e. the TCP listener that does run jobs.
async fn handle_h3_request(
    tokens: Arc<RwLock<HashSet<String>>>,
    req: Request<mechanics_http_server::H3RequestBody>,
) -> Result<Response<Full<Bytes>>, Infallible> {
    if req.method() != Method::POST || req.uri().path() != "/api/v1/mechanics" {
        return Ok(ApiError::NotFound.to_h3_response());
    }
    if !is_authorized_headers(&tokens, req.headers()) {
        return Ok(ApiError::Unauthorized.to_h3_response());
    }
    Ok(json_response_bytes(
        StatusCode::MISDIRECTED_REQUEST,
        &serde_json::json!({
            "error": "HTTP/3 job execution is not supported; retry over HTTP/1.1 or HTTP/2"
        }),
    ))
}
#[cfg(feature = "https")]
/// Adds an `Alt-Svc` header advertising HTTP/3 on `h3_port` of the same host, cacheable
/// for `max_age_secs` seconds.
///
/// The header is built from two integers, so it is always valid in practice. If it ever
/// were not, the response is returned without `Alt-Svc`. Advertising nothing is safer
/// than the old hard-coded `:443` fallback, which could point clients at a port the
/// server is not listening on.
fn with_alt_svc(mut response: HttpResponse, h3_port: u16, max_age_secs: u64) -> HttpResponse {
    let value = format!("h3=\":{h3_port}\"; ma={max_age_secs}");
    if let Ok(header) = HeaderValue::from_str(&value) {
        response.headers_mut().insert(ALT_SVC, header);
    }
    response
}
#[derive(Clone)]
pub struct MechanicsServer {
pool: Arc<MechanicsPool>,
tokens: Arc<RwLock<HashSet<String>>>,
}
impl MechanicsServer {
pub fn new(config: MechanicsPoolConfig) -> std::io::Result<Self> {
let pool = Arc::new(MechanicsPool::new(config).map_err(std::io::Error::other)?);
Ok(Self {
pool,
tokens: Arc::new(RwLock::default()),
})
}
pub fn add_token(&self, token: String) {
let token = token.trim();
if token.is_empty() {
return;
}
self.tokens.write().insert(token.to_string());
}
pub fn replace_tokens<I>(&self, tokens: I)
where
I: IntoIterator<Item = String>,
{
let new: HashSet<String> = tokens
.into_iter()
.map(|t| t.trim().to_string())
.filter(|t| !t.is_empty())
.collect();
*self.tokens.write() = new;
}
pub(crate) fn pool(&self) -> Arc<MechanicsPool> {
Arc::clone(&self.pool)
}
pub fn run(&self, bind_addr: SocketAddr) -> std::io::Result<()> {
let std_listener = std::net::TcpListener::bind(bind_addr)?;
std_listener.set_nonblocking(true)?;
let server = self.clone();
std::thread::Builder::new()
.name("MechanicsServer".to_string())
.spawn(move || {
let rt = tokio::runtime::Builder::new_multi_thread()
.enable_all()
.build()?;
rt.block_on(async move {
let listener = TcpListener::from_std(std_listener)?;
loop {
let (stream, _) = listener.accept().await?;
let io = TokioIo::new(stream);
let pool = server.pool();
let tokens = Arc::clone(&server.tokens);
tokio::task::spawn(async move {
let service = service_fn(move |req| {
handle_request(pool.clone(), Arc::clone(&tokens), req)
});
let _ = http1::Builder::new().serve_connection(io, service).await;
});
}
#[allow(unreachable_code)]
Ok::<_, std::io::Error>(())
})
})?;
Ok(())
}
#[cfg(feature = "https")]
pub fn run_tls(&self, bind_addr: SocketAddr, tls_config: TlsConfig) -> std::io::Result<()> {
let acceptor = tls_config.into_acceptor()?;
let std_listener = std::net::TcpListener::bind(bind_addr)?;
std_listener.set_nonblocking(true)?;
let server = self.clone();
std::thread::Builder::new()
.name("MechanicsServer-tls".to_string())
.spawn(move || {
let rt = tokio::runtime::Builder::new_multi_thread()
.enable_all()
.build()?;
rt.block_on(async move {
let listener = TcpListener::from_std(std_listener)?;
loop {
let (stream, _) = listener.accept().await?;
let tls_stream = match acceptor.accept(stream).await {
Ok(s) => s,
Err(_) => continue,
};
let io = TokioIo::new(tls_stream);
let pool = server.pool();
let tokens = Arc::clone(&server.tokens);
tokio::task::spawn(async move {
let service = service_fn(move |req| {
let pool = pool.clone();
let tokens = Arc::clone(&tokens);
async move {
handle_request(pool, tokens, req).await.map(with_hsts)
}
});
let _ = hyper_util::server::conn::auto::Builder::new(
hyper_util::rt::TokioExecutor::new(),
)
.serve_connection(io, service)
.await;
});
}
#[allow(unreachable_code)]
Ok::<_, std::io::Error>(())
})
})?;
Ok(())
}
#[cfg(feature = "https")]
pub fn run_tls_with_h3(
&self,
bind_addr: SocketAddr,
tls_config: TlsConfig,
h3_config: Option<Http3ServerConfig>,
) -> std::io::Result<()> {
let Some(h3_config) = h3_config else {
return self.run_tls(bind_addr, tls_config);
};
let h3_bind = h3_config.bind_h3;
let h3_port = h3_bind.map(|addr| addr.port());
let h3_max_age_secs = h3_config.alt_svc_max_age_secs;
let (acceptor, h3_cert_chain, h3_private_key) =
tls_config.into_acceptor_and_h3_material()?;
let std_listener = std::net::TcpListener::bind(bind_addr)?;
std_listener.set_nonblocking(true)?;
let server = self.clone();
std::thread::Builder::new()
.name("MechanicsServer-tls-h3".to_string())
.spawn(move || {
let rt = tokio::runtime::Builder::new_multi_thread()
.enable_all()
.build()?;
rt.block_on(async move {
let listener = TcpListener::from_std(std_listener)?;
let h3_handle = if h3_bind.is_some() {
let tokens = Arc::clone(&server.tokens);
let h3_service = tower_service_fn(move |req| {
handle_h3_request(Arc::clone(&tokens), req)
});
Some(
mechanics_http_server::Http3Server::new(h3_config)
.start(h3_service, h3_cert_chain, h3_private_key)
.map_err(std::io::Error::other)?,
)
} else {
None
};
let tcp_task = async move {
loop {
let (stream, _) = listener.accept().await?;
let tls_stream = match acceptor.accept(stream).await {
Ok(s) => s,
Err(_) => continue,
};
let io = TokioIo::new(tls_stream);
let pool = server.pool();
let tokens = Arc::clone(&server.tokens);
tokio::task::spawn(async move {
let service = service_fn(move |req| {
let pool = pool.clone();
let tokens = Arc::clone(&tokens);
async move {
handle_request(pool, tokens, req).await.map(|response| {
let response = with_hsts(response);
match h3_port {
Some(port) => {
with_alt_svc(response, port, h3_max_age_secs)
}
None => response,
}
})
}
});
let _ = hyper_util::server::conn::auto::Builder::new(
hyper_util::rt::TokioExecutor::new(),
)
.serve_connection(io, service)
.await;
});
}
#[allow(unreachable_code)]
Ok::<_, std::io::Error>(())
};
if let Some(mut h3_handle) = h3_handle {
tokio::select! {
tcp_result = tcp_task => {
h3_handle.shutdown();
tcp_result
}
h3_result = &mut h3_handle => {
h3_result.map_err(std::io::Error::other)
}
}
} else {
tcp_task.await
}
})
})?;
Ok(())
}
}
#[cfg(test)]
mod tests {
    use super::parse_bearer_token;

    /// Checks every `(header value, expected token)` pair against `parse_bearer_token`.
    fn assert_cases(cases: &[(&str, Option<&str>)]) {
        for &(input, expected) in cases {
            assert_eq!(parse_bearer_token(input), expected, "input: {input:?}");
        }
    }

    #[test]
    fn parse_bearer_token_accepts_case_insensitive_scheme() {
        assert_cases(&[
            ("Bearer abc", Some("abc")),
            ("bearer abc", Some("abc")),
            ("BEARER abc", Some("abc")),
        ]);
    }

    #[test]
    fn parse_bearer_token_accepts_flexible_whitespace() {
        assert_cases(&[
            (" Bearer abc ", Some("abc")),
            ("\tBearer\tabc\t", Some("abc")),
        ]);
    }

    #[test]
    fn parse_bearer_token_rejects_invalid_values() {
        assert_cases(&[
            ("Basic abc", None),
            ("Bearer", None),
            ("Bearer ", None),
            ("Bearer abc def", None),
        ]);
    }
}