use crate::app_state::{AppState, SharedAppState};
use crate::cli::CommandLineArgs;
use crate::error::ActiveStorageError;
use crate::filter_pipeline;
use crate::metrics::{metrics_handler, track_metrics};
use crate::models::{self, CBORResponse};
use crate::operation;
use crate::operations;
use crate::validated_json::ValidatedJson;
use axum::middleware;
use axum::{
extract::{Path, State},
headers::authorization::{Authorization, Basic},
http::header,
http::StatusCode,
response::{IntoResponse, Response},
routing::{get, post},
Router, TypedHeader,
};
use bytes::Bytes;
use serde_cbor;
use tower::Layer;
use tower::ServiceBuilder;
use tower_http::normalize_path::NormalizePathLayer;
use tower_http::trace::TraceLayer;
use tracing::debug_span;
impl IntoResponse for models::Response {
fn into_response(self) -> Response {
(
StatusCode::OK,
[(&header::CONTENT_TYPE, "application/cbor")],
serde_cbor::to_vec(&CBORResponse::new(&self))
.map_err(|e| log::error!("Failed to serialize CBOR: {e}"))
.unwrap(),
)
.into_response()
}
}
/// One-time process-wide initialization driven by the command line.
///
/// When `args.use_rayon` is set, this builds the global Rayon thread pool
/// with one fewer worker than the number of CPUs reported by `num_cpus`,
/// but never fewer than one worker.
///
/// # Panics
///
/// Panics if the global Rayon pool cannot be built, e.g. if it was already
/// initialized.
pub fn init(args: &CommandLineArgs) {
    if args.use_rayon {
        // Presumably one core is left free for the async runtime.
        // `num_cpus::get() - 1` would underflow on a single-CPU host (panic in
        // debug builds, wrap to usize::MAX in release), so saturate and
        // clamp to at least one thread.
        let num_threads = num_cpus::get().saturating_sub(1).max(1);
        rayon::ThreadPoolBuilder::new()
            .num_threads(num_threads)
            .build_global()
            .expect("Failed to build Rayon thread pool");
    }
}
/// Assemble the application router.
///
/// Top-level routes:
/// * `GET /.well-known/reductionist-schema` -> [`schema`]
/// * `GET /metrics` -> `metrics_handler`
/// * `POST /v2/<op>` -> the per-operation handlers, sharing one app state
///
/// Every matched route is wrapped in the `track_metrics` middleware.
fn router(args: &CommandLineArgs) -> Router {
    let shared_state = SharedAppState::new(AppState::new(args));

    // HTTP request tracing, applied to the /v2 routes only.
    let http_tracing = ServiceBuilder::new().layer(TraceLayer::new_for_http());

    // Supported reductions; any other `/v2/<name>` falls through to
    // `unknown_operation_handler`.
    let v2_routes: Router = Router::new()
        .route("/count", post(operation_handler::<operations::Count>))
        .route("/max", post(operation_handler::<operations::Max>))
        .route("/min", post(operation_handler::<operations::Min>))
        .route("/select", post(operation_handler::<operations::Select>))
        .route("/sum", post(operation_handler::<operations::Sum>))
        .route("/:operation", post(unknown_operation_handler))
        .layer(http_tracing)
        .with_state(shared_state);

    Router::new()
        .route("/.well-known/reductionist-schema", get(schema))
        .route("/metrics", get(metrics_handler))
        .nest("/v2", v2_routes)
        .route_layer(middleware::from_fn(track_metrics))
}
/// The application router wrapped in path normalization; returned by [`service`].
pub type Service = tower_http::normalize_path::NormalizePath<Router>;

/// Build the complete HTTP service.
///
/// Trailing slashes are trimmed from request paths before they reach the
/// router, so `/v2/sum/` and `/v2/sum` resolve to the same route.
pub fn service(args: &CommandLineArgs) -> Service {
    let trim_slashes = NormalizePathLayer::trim_trailing_slash();
    let app = router(args);
    trim_slashes.layer(app)
}
/// Handler for `GET /.well-known/reductionist-schema`.
///
/// NOTE(review): this returns a fixed placeholder body rather than a schema
/// document. Confirm whether that is intended.
async fn schema() -> &'static str {
    const PLACEHOLDER_BODY: &str = "Hello, world!";
    PLACEHOLDER_BODY
}
/// Generic handler behind each supported `POST /v2/<operation>` route.
///
/// 1. Reserves memory from the resource manager, sized by `request_data.size`
///    (zero when absent).
/// 2. Fetches the chunk through the shared chunk store, handing over the
///    memory permits and the optional Basic-auth credentials.
/// 3. Runs operation `T` on the chunk. With `use_rayon` the work goes to the
///    Rayon pool; otherwise it runs inline on this task while a task permit
///    is held.
///
/// Errors from any step are returned as `ActiveStorageError`.
async fn operation_handler<T: operation::Operation>(
    State(state): State<SharedAppState>,
    auth: Option<TypedHeader<Authorization<Basic>>>,
    ValidatedJson(request_data): ValidatedJson<models::RequestData>,
) -> Result<models::Response, ActiveStorageError> {
    let requested_bytes = request_data.size.unwrap_or(0);
    let memory_permits = state.resource_manager.memory(requested_bytes).await?;
    let chunk = state
        .chunk_store
        .get(&auth, &request_data, &state.resource_manager, memory_permits)
        .await?;

    if !state.args.use_rayon {
        // Inline execution: keep the task permit alive until the operation
        // has produced its result.
        let _task_permit = state.resource_manager.task().await?;
        return operation::<T>(request_data, chunk);
    }

    tokio_rayon::spawn(move || operation::<T>(request_data, chunk)).await
}
/// Decode a downloaded chunk and execute operation `T` on it.
///
/// This is synchronous, so it can run either inline or on a Rayon worker.
/// The raw bytes pass through the filter pipeline. When the data was
/// compressed, or no `size` was supplied, the decoded length is validated
/// against `dtype`/`shape`.
///
/// # Panics
///
/// Deliberately panics, in release builds too, if a copy was made where
/// zero-copy is expected:
/// * the pipeline returned a different buffer although neither compression
///   nor filters were requested, or
/// * converting the decoded bytes into a `Vec<u8>` reallocated.
fn operation<T: operation::Operation>(
    request_data: models::RequestData,
    data: Bytes,
) -> Result<models::Response, ActiveStorageError> {
    let raw_ptr = data.as_ptr();
    let decoded = filter_pipeline::filter_pipeline(&request_data, data)?;

    let compressed = request_data.compression.is_some();
    let size_unknown = request_data.size.is_none();
    if compressed || size_unknown {
        models::validate_raw_size(decoded.len(), request_data.dtype, &request_data.shape)?;
    }

    // With no decoding work to do, the pipeline should hand back the very
    // same buffer.
    let passthrough = !compressed && request_data.filters.is_none();
    if passthrough {
        assert_eq!(raw_ptr, decoded.as_ptr());
    }

    // Take ownership as a Vec without copying. Presumably this is so the
    // operation can work on the buffer in place; confirm in `Operation::execute`.
    let decoded_ptr = decoded.as_ptr();
    let owned: Vec<u8> = decoded.into();
    assert_eq!(decoded_ptr, owned.as_ptr());

    debug_span!("operation").in_scope(|| T::execute(&request_data, owned))
}
/// Fallback for `POST /v2/<name>` when `<name>` is not a supported
/// operation. Responds with `ActiveStorageError::UnsupportedOperation`,
/// carrying the requested name.
async fn unknown_operation_handler(path: Path<String>) -> ActiveStorageError {
    let Path(operation) = path;
    ActiveStorageError::UnsupportedOperation { operation }
}