use std::collections::HashSet;
use std::convert::Infallible;
use std::net::SocketAddr;
use std::sync::Arc;
use parking_lot::RwLock;
use http_body_util::{BodyExt, Full};
use hyper::body::{Bytes, Incoming};
use hyper::header::{AUTHORIZATION, CONTENT_TYPE, WWW_AUTHENTICATE};
use hyper::server::conn::http1;
use hyper::service::service_fn;
use hyper::{Method, Request, Response, StatusCode};
use hyper_util::rt::TokioIo;
use mechanics_core::job::MechanicsJob;
use tokio::net::TcpListener;
use mechanics_core::MechanicsPool;
pub use mechanics_core::MechanicsPoolConfig;
/// Response type returned by the handlers in this file: an HTTP response whose
/// body is a single, fully buffered [`Bytes`] chunk.
type HttpResponse = Response<Full<Bytes>>;
/// Failure modes of the HTTP API.
///
/// Each variant is rendered to a JSON error response by `ApiError::to_response`.
enum ApiError {
/// The request's method/path pair is not the job endpoint (rendered as 404).
NotFound,
/// No acceptable bearer token was presented (rendered as 401).
Unauthorized,
/// The request did not declare a JSON content type (rendered as 400).
InvalidType,
/// The body could not be read or did not deserialize into a job (rendered as 400).
InvalidRequest,
/// The pool reported an error while running the job; carries the error's
/// display text, which is sent back to the client verbatim (rendered as 400).
Pool(String),
/// The blocking task that runs the job could not be joined (rendered as 500).
Internal,
}
impl ApiError {
fn to_response(&self) -> HttpResponse {
let (status, message) = match self {
Self::NotFound => (StatusCode::NOT_FOUND, "Not found".to_string()),
Self::Unauthorized => (StatusCode::UNAUTHORIZED, "Unauthorized".to_string()),
Self::InvalidType => (StatusCode::BAD_REQUEST, "Invalid type".to_string()),
Self::InvalidRequest => (StatusCode::BAD_REQUEST, "Invalid request".to_string()),
Self::Pool(err) => (StatusCode::BAD_REQUEST, err.clone()),
Self::Internal => (
StatusCode::INTERNAL_SERVER_ERROR,
"Internal server error".to_string(),
),
};
let mut response = json_response(status, &serde_json::json!({ "error": message }));
if matches!(self, Self::Unauthorized) {
response.headers_mut().insert(
WWW_AUTHENTICATE,
hyper::header::HeaderValue::from_static("Bearer"),
);
}
response
}
}
/// Builds a response with the given `status`, a `Content-Type: application/json`
/// header, and `value` serialized as the body.
///
/// If serialization fails, the body falls back to the empty object `{}`.
fn json_response(status: StatusCode, value: &serde_json::Value) -> HttpResponse {
    let payload = match serde_json::to_vec(value) {
        Ok(bytes) => bytes,
        Err(_) => b"{}".to_vec(),
    };
    let content_type = hyper::header::HeaderValue::from_static("application/json");

    let mut reply = Response::new(Full::new(Bytes::from(payload)));
    reply.headers_mut().insert(CONTENT_TYPE, content_type);
    *reply.status_mut() = status;
    reply
}
/// Reports whether the request's `Content-Type` names `application/json`.
///
/// Only the part before the first `;` is considered (so parameters such as a
/// charset are ignored); it is trimmed and compared ASCII-case-insensitively.
/// A missing header, or a value that `to_str` rejects, yields `false`.
fn has_json_content_type(req: &Request<Incoming>) -> bool {
    let Some(raw) = req.headers().get(CONTENT_TYPE) else {
        return false;
    };
    let Ok(text) = raw.to_str() else {
        return false;
    };
    match text.split(';').next() {
        Some(media_type) => media_type.trim().eq_ignore_ascii_case("application/json"),
        None => false,
    }
}
/// Extracts the token from an `Authorization` header value of the form
/// `Bearer <token>`.
///
/// The value is split on whitespace and must consist of exactly two fields:
/// the scheme `bearer` (compared ASCII-case-insensitively) followed by the
/// token. Anything else yields `None`.
fn parse_bearer_token(header_value: &str) -> Option<&str> {
    let mut fields = header_value.split_whitespace();
    // Pull at most three fields: scheme, token, and a sentinel that must be absent.
    match (fields.next(), fields.next(), fields.next()) {
        (Some(scheme), Some(credential), None)
            if scheme.eq_ignore_ascii_case("bearer") && !credential.is_empty() =>
        {
            Some(credential)
        }
        _ => None,
    }
}
fn bearer_token(req: &Request<Incoming>) -> Option<&str> {
req.headers()
.get(AUTHORIZATION)
.and_then(|value| value.to_str().ok())
.and_then(parse_bearer_token)
}
/// Checks whether the request presents a bearer token contained in `tokens`.
///
/// The read lock is only taken once a token has been extracted from the request.
fn is_authorized(tokens: &RwLock<HashSet<String>>, req: &Request<Incoming>) -> bool {
    match bearer_token(req) {
        Some(presented) => tokens.read().contains(presented),
        None => false,
    }
}
async fn parse_json_job(req: Request<Incoming>) -> Result<MechanicsJob, ApiError> {
if !has_json_content_type(&req) {
return Err(ApiError::InvalidType);
}
let body = req
.into_body()
.collect()
.await
.map_err(|_| ApiError::InvalidRequest)?
.to_bytes();
serde_json::from_slice(&body).map_err(|_| ApiError::InvalidRequest)
}
/// Runs `job` on `pool` via `spawn_blocking`, so the synchronous `pool.run`
/// call does not execute on the async task itself.
///
/// # Errors
///
/// * [`ApiError::Internal`] when the blocking task cannot be joined.
/// * [`ApiError::Pool`] carrying the display text of the error returned by
///   the pool.
async fn execute_job(
    pool: Arc<MechanicsPool>,
    job: MechanicsJob,
) -> Result<serde_json::Value, ApiError> {
    match tokio::task::spawn_blocking(move || pool.run(job)).await {
        Err(_join_failure) => Err(ApiError::Internal),
        Ok(Err(pool_failure)) => Err(ApiError::Pool(pool_failure.to_string())),
        Ok(Ok(output)) => Ok(output),
    }
}
/// Serves a single HTTP request.
///
/// Checks are applied in this order, and the first failure decides the reply:
/// 1. anything other than `POST /api/v1/mechanics` is answered with "not found";
/// 2. the request must carry a bearer token known to `tokens`;
/// 3. the body must parse into a job;
/// 4. the job is executed on `pool` and its result returned as JSON with `200 OK`.
///
/// The function never returns `Err`: every failure is turned into an error response.
async fn handle_request(
    pool: Arc<MechanicsPool>,
    tokens: Arc<RwLock<HashSet<String>>>,
    req: Request<Incoming>,
) -> Result<HttpResponse, Infallible> {
    let targets_job_endpoint =
        req.method() == Method::POST && req.uri().path() == "/api/v1/mechanics";
    if !targets_job_endpoint {
        return Ok(ApiError::NotFound.to_response());
    }
    if !is_authorized(&tokens, &req) {
        return Ok(ApiError::Unauthorized.to_response());
    }

    let outcome = match parse_json_job(req).await {
        Ok(job) => execute_job(pool, job).await,
        Err(rejection) => Err(rejection),
    };
    let reply = match outcome {
        Ok(output) => json_response(StatusCode::OK, &output),
        Err(failure) => failure.to_response(),
    };
    Ok(reply)
}
/// HTTP front-end that runs submitted jobs on a [`MechanicsPool`].
///
/// Both fields are reference-counted, so a clone shares the same pool and the
/// same token set as the value it was cloned from.
#[derive(Clone)]
pub struct MechanicsServer {
// Pool on which accepted jobs are run.
pool: Arc<MechanicsPool>,
// Bearer tokens accepted for authorization.
tokens: Arc<RwLock<HashSet<String>>>,
}
impl MechanicsServer {
    /// Creates a server backed by a new [`MechanicsPool`] built from `config`.
    ///
    /// The token set starts out empty, so no request is authorized until
    /// [`add_token`](Self::add_token) has been called.
    ///
    /// # Errors
    ///
    /// Returns the pool's construction error wrapped in a [`std::io::Error`].
    pub fn new(config: MechanicsPoolConfig) -> std::io::Result<Self> {
        let pool = Arc::new(MechanicsPool::new(config).map_err(std::io::Error::other)?);
        Ok(Self {
            pool,
            tokens: Arc::new(RwLock::default()),
        })
    }

    /// Registers `token` as an accepted bearer token.
    ///
    /// Surrounding whitespace is trimmed first; a token that is empty after
    /// trimming is ignored.
    pub fn add_token(&self, token: String) {
        let token = token.trim();
        if token.is_empty() {
            return;
        }
        self.tokens.write().insert(token.to_string());
    }

    /// Returns a new handle to the shared pool.
    pub(crate) fn pool(&self) -> Arc<MechanicsPool> {
        Arc::clone(&self.pool)
    }

    /// Binds `bind_addr` and starts serving HTTP/1 connections on a background
    /// thread named `MechanicsServer`.
    ///
    /// The socket is bound on the calling thread, so binding failures are
    /// reported to the caller. The call returns as soon as the thread has been
    /// spawned; it does not wait for the server to stop.
    ///
    /// # Errors
    ///
    /// Returns an error when the address cannot be bound, the listener cannot
    /// be switched to non-blocking mode, or the thread cannot be spawned.
    pub fn run(&self, bind_addr: SocketAddr) -> std::io::Result<()> {
        let std_listener = std::net::TcpListener::bind(bind_addr)?;
        std_listener.set_nonblocking(true)?;
        let server = self.clone();
        std::thread::Builder::new()
            .name("MechanicsServer".to_string())
            .spawn(move || {
                // The join handle is dropped by `run`, so nobody would ever
                // observe this result; report it here instead of losing it.
                if let Err(err) = server.serve(std_listener) {
                    eprintln!("MechanicsServer stopped: {err:?}");
                }
            })?;
        Ok(())
    }

    /// Builds a current-thread Tokio runtime and drives the accept loop on it.
    fn serve(&self, std_listener: std::net::TcpListener) -> std::io::Result<()> {
        let rt = tokio::runtime::Builder::new_current_thread()
            .enable_all()
            .build()?;
        rt.block_on(self.accept_loop(std_listener))
    }

    /// Accepts connections until a fatal listener error occurs, serving each
    /// connection on its own task.
    async fn accept_loop(&self, std_listener: std::net::TcpListener) -> std::io::Result<()> {
        let listener = TcpListener::from_std(std_listener)?;
        loop {
            let stream = match listener.accept().await {
                Ok((stream, _peer)) => stream,
                // A connection that failed before it could be accepted says
                // nothing about the listener itself; keep serving.
                Err(err) if Self::is_transient_accept_error(&err) => {
                    eprintln!("Error accepting connection: {err:?}");
                    continue;
                }
                Err(err) => return Err(err),
            };

            let io = TokioIo::new(stream);
            let pool = self.pool();
            let tokens = Arc::clone(&self.tokens);
            tokio::task::spawn(async move {
                let service = service_fn(move |req| {
                    handle_request(pool.clone(), Arc::clone(&tokens), req)
                });
                if let Err(err) = http1::Builder::new().serve_connection(io, service).await {
                    eprintln!("Error serving connection: {err:?}");
                }
            });
        }
    }

    /// Classifies accept errors that concern only the connection being
    /// accepted, as opposed to the listening socket.
    fn is_transient_accept_error(err: &std::io::Error) -> bool {
        matches!(
            err.kind(),
            std::io::ErrorKind::ConnectionAborted
                | std::io::ErrorKind::ConnectionReset
                | std::io::ErrorKind::ConnectionRefused
        )
    }
}
#[cfg(test)]
mod tests {
    use super::parse_bearer_token;

    /// The scheme keyword is accepted in any ASCII case.
    #[test]
    fn parse_bearer_token_accepts_case_insensitive_scheme() {
        for header in ["Bearer abc", "bearer abc", "BEARER abc"] {
            assert_eq!(parse_bearer_token(header), Some("abc"), "header: {header:?}");
        }
    }

    /// Leading, trailing and separating whitespace may be spaces or tabs.
    #[test]
    fn parse_bearer_token_accepts_flexible_whitespace() {
        for header in [" Bearer abc ", "\tBearer\tabc\t"] {
            assert_eq!(parse_bearer_token(header), Some("abc"), "header: {header:?}");
        }
    }

    /// Wrong scheme, missing token, and extra fields are all rejected.
    #[test]
    fn parse_bearer_token_rejects_invalid_values() {
        for header in ["Basic abc", "Bearer", "Bearer ", "Bearer abc def"] {
            assert_eq!(parse_bearer_token(header), None, "header: {header:?}");
        }
    }
}