use rocket::{catch, data::ByteUnit, get, http::{ContentType, Status}, post, request::{FromRequest, Outcome}, response::{status, Responder}, serde::json::Json, Data, Request, Response, Shutdown, State};
use serde_json::{json, Value};
use std::collections::{HashMap, HashSet};
use std::error::Error;
use std::sync::{
atomic::{AtomicBool, Ordering},
Arc,
};
use std::time::Duration;
use tokio::{
io::AsyncReadExt,
spawn,
sync::{mpsc, Notify, Semaphore},
task::JoinHandle,
time::timeout,
};
use tokio_stream::{wrappers::ReceiverStream, StreamExt};
use async_stream::stream;
use async_trait::async_trait;
use rocket::data::{Limits, ToByteUnit};
use rocket::mtls::Certificate;
use crate::{
manager::{Base, SharedState, Config},
multibus::MultiBus,
};
/// Newtype that lets a fully built Rocket `Response` be returned directly from a route.
pub struct CustomResponse(Response<'static>);

impl<'r> Responder<'r, 'static> for CustomResponse {
    /// Yields the wrapped response unchanged; the incoming request is not consulted.
    fn respond_to(self, _: &'r Request<'_>) -> rocket::response::Result<'static> {
        let CustomResponse(inner) = self;
        Ok(inner)
    }
}
pub struct Authentication;
#[async_trait]
impl<'r> FromRequest<'r> for Authentication {
type Error = ();
async fn from_request(request: &'r Request<'_>) -> Outcome<Self, Self::Error> {
let api_key: &State<Config> = request.guard::<&State<Config>>().await.unwrap();
let auth_header = request.headers().get_one("Authorization");
if let Some(auth_header) = auth_header {
if auth_header == api_key.inner().api_key {
Outcome::Success(Authentication)
} else {
Outcome::Error((Status::Unauthorized, ()))
}
} else {
Outcome::Error((Status::Unauthorized, ()))
}
}
}
/// Catcher for `401 Unauthorized`: answers with a JSON object holding a fixed `message`.
#[catch(401)]
pub fn unauthorized(_req: &Request) -> status::Custom<Json<Value>> {
    let message = "Unauthorized: Access is denied due to invalid credentials.";
    status::Custom(Status::Unauthorized, Json(json!({ "message": message })))
}
/// Catcher for `403 Forbidden`: answers with a JSON object holding a fixed `message`
/// about a rejected certificate common name.
#[catch(403)]
pub fn invalid_cn(_req: &Request) -> status::Custom<Json<Value>> {
    let message = "Forbidden: Access is denied due to invalid CN name.";
    status::Custom(Status::Forbidden, Json(json!({ "message": message })))
}
/// A wrapper for the `Content-Length` header value.
pub struct LowerBound(usize);
#[async_trait]
impl<'r> FromRequest<'r> for LowerBound {
type Error = ();
/// Extracts the `Lower-Bound` header value from the request.
///
/// # Arguments
///
/// * `request` - The incoming HTTP request.
///
/// # Returns
///
/// * `Outcome::Success(LowerBound)` if the header is present and valid.
/// * `Outcome::Error(Status::BadRequest)` if the header is present but the value is invalid.
/// * `Outcome::Error(Status::LengthRequired)` if the header is missing.
async fn from_request(request: &'r Request<'_>) -> Outcome<Self, Self::Error> {
if let Some(value) = request.headers().get_one("Lower-Bound") {
match value.parse::<usize>() {
Ok(length) => Outcome::Success(LowerBound(length)),
Err(_) => Outcome::Error((Status::BadRequest, ())),
}
} else {
Outcome::Error((Status::LengthRequired, ()))
}
}
}
pub struct CNCheck;
#[async_trait]
impl<'r> FromRequest<'r> for CNCheck {
type Error = ();
async fn from_request(request: &'r Request<'_>) -> Outcome<Self, Self::Error> {
let cert: Certificate = request.guard::<Certificate<'r>>().await.unwrap();
let allowed_names_wrapped = request.guard::<&State<HashSet<String>>>().await;
if allowed_names_wrapped.is_error() {
return Outcome::Success(CNCheck);
}
let allowed_names = allowed_names_wrapped.unwrap();
if let Some(cert_name) = cert.subject().common_name() {
if allowed_names.contains(cert_name) {
return Outcome::Success(CNCheck);
}
}
Outcome::Error((Status::Forbidden, ()))
}
}
async fn handle_common_ending(processor: JoinHandle<Result<String, Box<dyn Error + Send + Sync>>>, manager: &State<Arc<Notify>>, str_handler: String) -> Value {
tokio::select! {
result = processor => {
println!("Processed a request to handler {}", str_handler);
match result {
Ok(result) => {
match result {
Ok(output) => json!({"status": "Ok", "message": output}),
Err(e) => json!({"status": "Error", "message": e.to_string()}),
}
}
Err(e) => {
json!({"status": "Error", "message": e.to_string()})
}
}
},
_ = manager.notified() => {
println!("Service is stopped forcefully, request aborted");
json!({"message": "Service closed forcefully"})
}
}
}
/// Handles `POST /shutdown`: starts Rocket's graceful shutdown, raises the shared
/// `has_been_called` flag and notifies the tasks waiting on `manager`.
///
/// # Arguments
/// * `shutdown` - Rocket's `Shutdown` handle; `notify()` asks the server to stop.
/// * `manager` - Shared `Notify` on which `notify_waiters()` is called.
/// * `has_been_called` - Flag set to `true` (with `SeqCst` ordering) to record that a
///   shutdown was requested.
///
/// # Returns
/// * `{"status": "Shutting down"}` as JSON.
///
/// NOTE(review): this route declares no `CNCheck`/`Authentication` guards, so any client
/// that can reach it can stop the service. Confirm this is intended.
#[post("/shutdown")]
pub async fn shutdown_from_http(shutdown: Shutdown, manager: &State<Arc<Notify>>, has_been_called: &State<AtomicBool>) -> Json<Value> {
    shutdown.notify();

    let flag: &AtomicBool = has_been_called.inner();
    flag.store(true, Ordering::SeqCst);

    let notifier: &Arc<Notify> = manager.inner();
    notifier.notify_waiters();

    let reply = json!({"status": "Shutting down"});
    Json(reply)
}
/// Handles HTTP requests to process specific handlers.
///
/// # Arguments
/// * `handler_name` - The name of the handler to process the request.
/// * `body` - The request body containing the data to be processed.
/// * `permits` - The semaphore controlling concurrency (only one request at a time).
/// * `instance` - The registered handlers mapped by name.
/// * `manager` - The shutdown notifier.
/// * `communication_line` - The communication line (MultiBus) for message exchange.
/// * `shared_state` - The variable used to access the shared state
///
/// # Returns
/// * A JSON response indicating the result of the handler's processing.
#[post("/<handler_name>", data = "<body>")]
pub async fn process_request(_cert: CNCheck, _auth: Authentication, handler_name: String, body: Data<'_>, permits: &State<Arc<Semaphore>>, instance: &State<HashMap<String, Arc<Box<dyn Fn () -> Box<dyn Base + Send + Sync> + Send + Sync>>>>, manager: &State<Arc<Notify>>, communication_line: &State<Arc<MultiBus>>, shared_state: &State<Arc<SharedState>>, request_max_per_handler: &State<Arc<HashMap<String, Arc<Semaphore>>>>) -> Value {
let mut body_string = String::new();
body.open(ByteUnit::max_value())
.read_to_string(&mut body_string)
.await
.expect("Failed to read request body");
let str_handler = handler_name.to_string();
let str_handler_copy = str_handler.clone();
let permits_clone = permits.inner().clone();
let permits_per_handler_clone = request_max_per_handler.inner().clone().get(&handler_name).unwrap().clone();
let communication_line_clone = communication_line.inner().clone();
let shared_state_clone = shared_state.inner().clone();
let instance_to_run: Arc<Box<dyn Fn () -> Box<dyn Base + Send + Sync> + Send + Sync>> = Arc::clone(instance.inner().get(&str_handler_copy).unwrap());
let processor = tokio::spawn(async move {
let permit = permits_clone.clone().acquire_owned().await.unwrap();
let permit_handler = permits_per_handler_clone.clone().acquire_owned().await.unwrap();
let result = instance_to_run().run(str_handler_copy, body_string, communication_line_clone.clone(), shared_state_clone.clone()).await;
drop(permit);
drop(permit_handler);
result
});
handle_common_ending(processor, &manager, str_handler).await
}
/// Handles file download requests.
///
/// # Arguments
///
/// * `file_id` - The unique identifier of the file to download.
/// * `permits` - A semaphore to manage concurrent downloads, ensuring controlled access.
/// * `instance` - A map of handler instances that are used to process the file retrieval logic.
/// * `manager` - A shutdown notifier to monitor and handle service shutdowns.
/// * `communication_line` - The communication line (`MultiBus`) for message exchange between components.
/// * `shared_state` - The shared state (`SharedState`) for managing application-wide state or data.
///
/// # Returns
///
/// * [`Ok(CustomResponse)`](CustomResponse) - On success, returns a response containing the file data as a binary stream.
/// * [`Err(status::Custom<String>)`](status::Custom) - On failure, returns an appropriate error status and message.
///
/// ## Failure Cases:
/// - If the file is not found or fails to process, returns a `500 Internal Server Error`.
/// - If the service is interrupted or forcefully stopped, returns a `503 Service Unavailable`.
/// - If there is a general error during the process, returns a `400 Bad Request`.
#[get("/download/<file_id>")]
pub async fn process_download(_cert: CNCheck, _auth: Authentication, file_id: String, permits: &State<Arc<Semaphore>>, instance: &State<HashMap<String, Arc<Box<dyn Fn () -> Box<dyn Base + Send + Sync> + Send + Sync>>>>, manager: &State<Arc<Notify>>, communication_line: &State<Arc<MultiBus>>, shared_state: &State<Arc<SharedState>>, request_max_per_handler: &State<Arc<HashMap<String, Arc<Semaphore>>>>) -> Result<CustomResponse, status::Custom<String>> {
let file_id_clone = file_id.clone();
let str_handler: String = "download".to_string();
let str_handler_copy = str_handler.clone();
let permits_clone = permits.inner().clone();
let permits_per_handler_clone = request_max_per_handler.inner().clone().get(&str_handler_copy).unwrap().clone();
let communication_line_clone = communication_line.inner().clone();
let shared_state_clone = shared_state.inner().clone();
let instance_to_run: Arc<Box<dyn Fn () -> Box<dyn Base + Send + Sync> + Send + Sync>> = Arc::clone(instance.inner().get(&str_handler_copy).unwrap());
let processor = tokio::spawn(async move {
let permit = permits_clone.clone().acquire_owned().await.unwrap();
let permit_handler = permits_per_handler_clone.clone().acquire_owned().await.unwrap();
let result = instance_to_run().run_file(str_handler_copy, file_id_clone, communication_line_clone.clone(), shared_state_clone.clone()).await;
drop(permit);
drop(permit_handler);
result
});
tokio::select! {
result = processor => {
println!("Processed a request to handler {}", str_handler);
match result {
Ok(result) => {
match result {
Ok((file, size)) => {
Ok(CustomResponse(
Response::build()
.header(ContentType::Binary)
.header(rocket::http::Header::new("Content-Length", size.to_string()))
.streamed_body(file)
.finalize()
)
)
},
Err(e) => {
let error_body = json!({
"status": "Error",
"message": e.to_string()
}).to_string();
Err(
status::Custom(
Status::InternalServerError,
error_body,
)
)
},
}
}
Err(e) => {
let error_body = json!({
"status": "Error",
"message": e.to_string()
}).to_string();
Err(
status::Custom(
Status::BadRequest,
error_body,
)
)
}
}
},
_ = manager.notified() => {
println!("Service is stopped forcefully, request aborted");
let error_body = json!({
"status": "Error",
"message": "Service closed forcefully"
}).to_string();
Err(
status::Custom(
Status::ServiceUnavailable,
error_body,
)
)
}
}
}
/// Handles file upload requests.
///
/// # Arguments
///
/// * `file_name` - The name of the file being uploaded.
/// * `body` - The request body containing the file data in a stream.
/// * `content_length` - The `Content-Length` header value, representing the size of the uploaded file.
/// * `permits` - A semaphore to manage concurrent uploads, ensuring controlled access.
/// * `instance` - A map of handler instances that are used to process the upload logic.
/// * `manager` - A shutdown notifier to monitor and handle service shutdowns.
/// * `communication_line` - The communication line (`MultiBus`) for message exchange between components.
/// * `shared_state` - The shared state (`SharedState`) for managing application-wide state or data.
/// * `has_been_called` - A flag indicating whether the service has been called or interrupted.
///
/// # Returns
///
/// A JSON response indicating the processing result of the uploaded file.
///
/// - On success, returns a JSON object with a `"status"` of `"Ok"` and a corresponding `"message"`.
/// - On failure, returns a JSON object with a `"status"` of `"Error"` and an appropriate error `"message"`.
#[post("/upload/<file_name>", data = "<body>")]
pub async fn process_upload<'a>(_cert: CNCheck, _auth: Authentication, file_name: String, body: Data<'a>, lower_bound: Option<LowerBound>, permits: &State<Arc<Semaphore>>, instance: &State<HashMap<String, Arc<Box<dyn Fn () -> Box<dyn Base + Send + Sync> + Send + Sync>>>>, manager: &State<Arc<Notify>>, communication_line: &State<Arc<MultiBus>>, shared_state: &State<Arc<SharedState>>, has_been_called: &State<AtomicBool>, limits: &Limits, request_max_per_handler: &State<Arc<HashMap<String, Arc<Semaphore>>>>) -> Value {
let file_name_cln = file_name.clone();
let lower_bound: usize = lower_bound.expect("Lower-Bound header should be defined").0;
let str_handler = "upload".to_string();
let str_handler_copy = str_handler.clone();
let permits_clone = permits.inner().clone();
let permits_per_handler_clone = request_max_per_handler.inner().clone().get(&str_handler_copy).unwrap().clone();
let communication_line_clone = communication_line.inner().clone();
let shared_state_clone = shared_state.inner().clone();
let (tx, mut rx) = mpsc::channel::<Vec<u8>>(64);
let instance_to_run: Arc<Box<dyn Fn () -> Box<dyn Base + Send + Sync> + Send + Sync>> = Arc::clone(instance.inner().get(&str_handler_copy).unwrap());
let processor = spawn(async move {
let permit = permits_clone.clone().acquire_owned().await.unwrap();
let permit_handler = permits_per_handler_clone.clone().acquire_owned().await.unwrap();
let mut stream_from_channel = ReceiverStream::new(rx);
let result = instance_to_run().run_stream(
str_handler_copy,
Box::pin(stream! {
while let Some(chunk) = stream_from_channel.next().await {
yield chunk;
}
}),
file_name_cln,
communication_line_clone.clone(),
shared_state_clone.clone()).await;
drop(permit);
drop(permit_handler);
result
});
let limit = limits.get("stream").unwrap_or(1.megabytes());
let mut stream = body.open(limit);
let mut buffer = vec![0; 16 * 1024];
let mut total_size = 0;
loop {
match timeout(Duration::from_secs(20), stream.read(&mut buffer)).await {
Ok(Ok(0)) => {
if total_size < lower_bound {
drop(tx);
return json!({"status": "Error", "message": "Length of the stream does not match the Lower-Bound header"});
}
break;
}
Ok(Ok(n)) => {
total_size += n;
if has_been_called.inner().load(Ordering::SeqCst) {
break;
}
if tx.send(buffer[..n].to_vec()).await.is_err() {
break;
}
}
Ok(Err(e)) => {
eprintln!("Error reading stream: {}", e);
break;
}
Err(_) => {
eprintln!("Stream read timed out after 20 seconds");
drop(tx);
return json!({"status": "Error", "message": "Stream read timed out"});
}
}
}
drop(tx);
handle_common_ending(processor, &manager, str_handler).await
}
/// Handles requests to retrieve metadata for a specific file.
///
/// # Arguments
/// * `file_id` - The unique identifier of the file whose metadata is being requested.
/// * `permits` - A semaphore to manage concurrent accesses, ensuring controlled concurrency.
/// * `instance` - A map containing registered handler instances, keyed by handler names.
/// * `manager` - A shutdown notifier to monitor and handle service shutdowns.
/// * `communication_line` - The communication line (`MultiBus`) for inter-component communication.
/// * `shared_state` - The shared state (`SharedState`) for managing application-wide state or data.
///
/// # Returns
/// * A JSON response containing the metadata if the operation succeeds.
/// * A JSON response with an error message if the operation fails.
#[get("/metadata/<file_id>")]
pub async fn process_metadata(_cert: CNCheck, _auth: Authentication, file_id: String, permits: &State<Arc<Semaphore>>, instance: &State<HashMap<String, Arc<Box<dyn Fn () -> Box<dyn Base + Send + Sync> + Send + Sync>>>>, manager: &State<Arc<Notify>>, communication_line: &State<Arc<MultiBus>>, shared_state: &State<Arc<SharedState>>, request_max_per_handler: &State<Arc<HashMap<String, Arc<Semaphore>>>>) -> Value {
let file_id_clone = file_id.clone();
let str_handler: String = "metadata".to_string();
let str_handler_copy = str_handler.clone();
let permits_clone = permits.inner().clone();
let permits_per_handler_clone = request_max_per_handler.inner().clone().get(&str_handler_copy).unwrap().clone();
let communication_line_clone = communication_line.inner().clone();
let shared_state_clone = shared_state.inner().clone();
let instance_to_run: Arc<Box<dyn Fn () -> Box<dyn Base + Send + Sync> + Send + Sync>> = Arc::clone(instance.inner().get(&str_handler_copy).unwrap());
let processor = spawn(async move {
let permit = permits_clone.clone().acquire_owned().await.unwrap();
let permit_handler = permits_per_handler_clone.clone().acquire_owned().await.unwrap();
let result = instance_to_run().run_metadata(str_handler_copy, file_id_clone, communication_line_clone.clone(), shared_state_clone.clone()).await;
drop(permit);
drop(permit_handler);
result
});
handle_common_ending(processor, &manager, str_handler).await
}