use std::time::Duration;
use axum::body::{to_bytes, Body};
use axum::http::header::{HeaderValue, CONTENT_TYPE};
use axum::http::{Request, Response, StatusCode, Uri};
use axum::response::sse::{Event, KeepAlive, Sse};
use axum::response::IntoResponse;
use bytes::Bytes;
use prometheus_client::encoding::text::encode;
use tracing::warn;
use utoipa::OpenApi;
use crate::errors::ServiceError;
use crate::metrics::{CodeLabel, Metrics, DURATION_LABEL, REGISTRY};
use crate::tasks::{TaskCode, TaskManager};
/// Value of the `server` response header: `<crate name>/<crate version>`, resolved at compile time.
const SERVER_INFO: &str = concat!(env!("CARGO_PKG_NAME"), "/", env!("CARGO_PKG_VERSION"));
/// Body returned by the root (`/`) handler while the service is running.
const RESPONSE_DEFAULT: &[u8] = b"MOSEC service";
/// Body returned by the inference handlers when the request body is empty.
const RESPONSE_EMPTY: &[u8] = b"no data provided";
/// Body returned by the inference handlers when reading the request body fails
/// (the read is capped at `DEFAULT_MAX_REQUEST_SIZE`).
const RESPONSE_TOO_LARGE: &[u8] = b"request body is too large";
/// Body returned once the task manager reports that it is shutting down.
const RESPONSE_SHUTDOWN: &[u8] = b"gracefully shutting down";
/// Content type used for successful inference responses when the task manager
/// has no MIME type registered for the endpoint.
const DEFAULT_RESPONSE_MIME: &str = "application/json";
/// Upper bound, in bytes, on how much of a request body is buffered (10 MiB).
const DEFAULT_MAX_REQUEST_SIZE: usize = 10 * 1024 * 1024;
/// Wraps `content` in a response with the given `status` and the service's
/// `server` header (`SERVER_INFO`).
///
/// The response is assembled with direct mutators, so there is no fallible
/// builder step to unwrap.
fn build_response(status: StatusCode, content: Bytes) -> Response<Body> {
    let mut response = Response::new(Body::from(content));
    *response.status_mut() = status;
    response
        .headers_mut()
        .insert("server", HeaderValue::from_static(SERVER_INFO));
    response
}
// Root handler, usable as a liveness probe.
//
// Responds 503 with the shutdown notice once the task manager reports
// shutdown, otherwise 200 with the default service banner.
//
// NOTE: plain `//` comments are used on purpose; `///` doc comments on a
// `utoipa::path` handler are picked up as the operation's OpenAPI description.
#[utoipa::path(
    get,
    path = "/",
    responses(
        (
            status = StatusCode::OK,
            description = "Root path, can be used for liveness health check",
            body = String,
        ),
        (
            status = StatusCode::SERVICE_UNAVAILABLE,
            description = "SERVICE_UNAVAILABLE",
            body = String,
        ),
    ),
)]
pub(crate) async fn index() -> Response<Body> {
    // Decide the (status, body) pair first, then build the response once.
    let (status, body) = if TaskManager::global().is_shutdown() {
        (StatusCode::SERVICE_UNAVAILABLE, RESPONSE_SHUTDOWN)
    } else {
        (StatusCode::OK, RESPONSE_DEFAULT)
    };
    build_response(status, Bytes::from_static(body))
}
#[utoipa::path(
get,
path = "/metrics",
responses(
(status = StatusCode::OK, description = "Get metrics", body = String),
),
)]
pub(crate) async fn metrics() -> Response<Body> {
let mut encoded = String::new();
let registry = REGISTRY.get().unwrap();
encode(&mut encoded, registry).unwrap();
build_response(StatusCode::OK, Bytes::from(encoded))
}
// Synchronous inference handler.
//
// Flow:
//   1. Resolve the response MIME type registered for the request path
//      (falls back to `DEFAULT_RESPONSE_MIME`).
//   2. Reject with 503 if the task manager is shutting down.
//   3. Buffer the request body, capped at `DEFAULT_MAX_REQUEST_SIZE`; a read
//      failure is answered with 413.
//   4. An empty body is answered with 200 and `RESPONSE_EMPTY`.
//   5. Submit the task and map its `TaskCode` / `ServiceError` to a status.
//   6. Record the throughput metric, and set `Content-Type` on success.
//
// The `remaining_task` gauge is incremented before submission and decremented
// once the submission future resolves.
//
// NOTE: plain `//` comments are used on purpose; `///` doc comments on a
// `utoipa::path` handler are picked up as the operation's OpenAPI description.
#[utoipa::path(
    post,
    path = "/openapi/reserved/inference",
    responses(
        (status = StatusCode::OK, description = "Inference"),
        (status = StatusCode::BAD_REQUEST, description = "BAD_REQUEST"),
        (status = StatusCode::SERVICE_UNAVAILABLE, description = "SERVICE_UNAVAILABLE"),
        (status = StatusCode::UNPROCESSABLE_ENTITY, description = "UNPROCESSABLE_ENTITY"),
        (status = StatusCode::REQUEST_TIMEOUT, description = "REQUEST_TIMEOUT"),
        (status = StatusCode::INTERNAL_SERVER_ERROR, description = "INTERNAL_SERVER_ERROR"),
        (status = StatusCode::TOO_MANY_REQUESTS, description = "TOO_MANY_REQUESTS"),
    ),
)]
pub(crate) async fn inference(uri: Uri, req: Request<Body>) -> Response<Body> {
    let task_manager = TaskManager::global();
    let endpoint = uri.path();
    let mime = match task_manager.get_mime_type(endpoint) {
        Some(mime) => mime.as_str(),
        None => DEFAULT_RESPONSE_MIME,
    };

    if task_manager.is_shutdown() {
        return build_response(
            StatusCode::SERVICE_UNAVAILABLE,
            Bytes::from_static(RESPONSE_SHUTDOWN),
        );
    }

    let data = match to_bytes(req.into_body(), DEFAULT_MAX_REQUEST_SIZE).await {
        Ok(data) => data,
        Err(err) => {
            warn!(?err, "failed to read request body (too large)");
            return build_response(
                StatusCode::PAYLOAD_TOO_LARGE,
                Bytes::from_static(RESPONSE_TOO_LARGE),
            );
        }
    };
    if data.is_empty() {
        return build_response(StatusCode::OK, Bytes::from_static(RESPONSE_EMPTY));
    }

    let metrics = Metrics::global();
    metrics.remaining_task.inc();
    let (status, content) = match task_manager.submit_task(data, endpoint).await {
        Ok(task) => {
            let status = match task.code {
                TaskCode::Normal => {
                    // Latency is only observed for successfully processed tasks.
                    metrics
                        .duration
                        .get_or_create(
                            DURATION_LABEL
                                .get()
                                .expect("DURATION_LABEL is not initialized"),
                        )
                        .observe(task.create_at.elapsed().as_secs_f64());
                    StatusCode::OK
                }
                TaskCode::BadRequestError => StatusCode::BAD_REQUEST,
                TaskCode::ValidationError => StatusCode::UNPROCESSABLE_ENTITY,
                TaskCode::TimeoutError => StatusCode::REQUEST_TIMEOUT,
                _ => StatusCode::INTERNAL_SERVER_ERROR,
            };
            (status, task.data)
        }
        Err(err) => {
            // Render the message before `err` is consumed by the match below.
            let content = Bytes::from(err.to_string());
            let status = match err {
                ServiceError::TooManyRequests => StatusCode::TOO_MANY_REQUESTS,
                ServiceError::Timeout => StatusCode::REQUEST_TIMEOUT,
                _ => StatusCode::INTERNAL_SERVER_ERROR,
            };
            (status, content)
        }
    };
    metrics.remaining_task.dec();
    metrics
        .throughput
        .get_or_create(&CodeLabel {
            code: status.as_u16(),
            endpoint: endpoint.to_string(),
        })
        .inc();

    let mut resp = build_response(status, content);
    if status == StatusCode::OK {
        // The MIME string comes from the task manager's endpoint configuration.
        // If it is not a valid header value, log it and omit the header rather
        // than panicking after the task has already been processed.
        match HeaderValue::from_str(mime) {
            Ok(value) => {
                resp.headers_mut().insert(CONTENT_TYPE, value);
            }
            Err(err) => {
                warn!(?err, mime, "invalid MIME type for endpoint, content-type header omitted");
            }
        }
    }
    resp
}
// Server-sent-events inference handler.
//
// Flow:
//   1. Reject with 503 if the task manager is shutting down.
//   2. Buffer the request body, capped at `DEFAULT_MAX_REQUEST_SIZE`; a read
//      failure is answered with 413.
//   3. An empty body is answered with 200 and `RESPONSE_EMPTY`.
//   4. Submit the SSE task. On success, every `(msg, code)` received from the
//      channel is forwarded as an SSE event; on failure a plain response with
//      the mapped status is returned and the throughput metric is recorded.
//
// Every non-streaming response goes through `build_response`, so it carries
// the same `server` header as the other handlers in this file.
//
// NOTE: plain `//` comments are used on purpose; `///` doc comments on a
// `utoipa::path` handler are picked up as the operation's OpenAPI description.
#[utoipa::path(
    post,
    path = "/openapi/reserved/inference_sse",
    responses(
        (status = StatusCode::OK, description = "Inference"),
        (status = StatusCode::BAD_REQUEST, description = "BAD_REQUEST"),
        (status = StatusCode::SERVICE_UNAVAILABLE, description = "SERVICE_UNAVAILABLE"),
        (status = StatusCode::UNPROCESSABLE_ENTITY, description = "UNPROCESSABLE_ENTITY"),
        (status = StatusCode::REQUEST_TIMEOUT, description = "REQUEST_TIMEOUT"),
        (status = StatusCode::INTERNAL_SERVER_ERROR, description = "INTERNAL_SERVER_ERROR"),
        (status = StatusCode::TOO_MANY_REQUESTS, description = "TOO_MANY_REQUESTS"),
    ),
)]
pub(crate) async fn sse_inference(uri: Uri, req: Request<Body>) -> Response<Body> {
    let task_manager = TaskManager::global();
    let endpoint = uri.path();

    if task_manager.is_shutdown() {
        return build_response(
            StatusCode::SERVICE_UNAVAILABLE,
            Bytes::from_static(RESPONSE_SHUTDOWN),
        );
    }

    let data = match to_bytes(req.into_body(), DEFAULT_MAX_REQUEST_SIZE).await {
        Ok(data) => data,
        Err(err) => {
            warn!(?err, "failed to read request body (too large)");
            return build_response(
                StatusCode::PAYLOAD_TOO_LARGE,
                Bytes::from_static(RESPONSE_TOO_LARGE),
            );
        }
    };
    if data.is_empty() {
        return build_response(StatusCode::OK, Bytes::from_static(RESPONSE_EMPTY));
    }

    let metrics = Metrics::global();
    match task_manager.submit_sse_task(data, endpoint).await {
        Ok(mut rx) => {
            let stream = async_stream::stream! {
                while let Some((msg, code)) = rx.recv().await {
                    yield match code {
                        // Regular chunk: forwarded as a default (unnamed) event.
                        TaskCode::Normal => {
                            Ok(Event::default().data(String::from_utf8_lossy(&msg)))
                        }
                        // Known task failures are reported in-band as an `error` event.
                        TaskCode::BadRequestError
                        | TaskCode::InternalError
                        | TaskCode::ValidationError
                        | TaskCode::TimeoutError => {
                            Ok(Event::default().event("error").data(format!(
                                "{}: {}",
                                ServiceError::SSEError(code),
                                String::from_utf8_lossy(&msg)
                            )))
                        }
                        // Any other code is yielded as a stream error.
                        _ => {
                            warn!(?code, ?msg, "unexpected error in SSE");
                            Err(ServiceError::SSEError(code))
                        }
                    }
                }
            };
            Sse::new(stream)
                .keep_alive(KeepAlive::new().interval(Duration::from_secs(3)))
                .into_response()
        }
        Err(err) => {
            // Render the message before `err` is consumed by the match below.
            let content = Bytes::from(err.to_string());
            let status = match err {
                ServiceError::TooManyRequests => StatusCode::TOO_MANY_REQUESTS,
                ServiceError::Timeout => StatusCode::REQUEST_TIMEOUT,
                _ => StatusCode::INTERNAL_SERVER_ERROR,
            };
            metrics
                .throughput
                .get_or_create(&CodeLabel {
                    code: status.as_u16(),
                    endpoint: endpoint.to_string(),
                })
                .inc();
            build_response(status, content)
        }
    }
}
// OpenAPI document root: collects the path definitions generated by the
// `utoipa::path` attributes on `index`, `metrics`, `inference` and
// `sse_inference`.
#[derive(OpenApi)]
#[openapi(paths(index, metrics, inference, sse_inference))]
pub(crate) struct RustAPIDoc;