use crate::security::AccessForbiddenError;
use crate::{
persistence::MappingError, persistence::RepositoryError, security::AuthorizationError,
server::error_codes,
};
use actix_web::{HttpRequest, HttpResponse, Responder, ResponseError, http::StatusCode};
use actix_web::{
HttpResponseBuilder, Result,
body::BoxBody,
dev::ServiceResponse,
error::{ErrorInternalServerError, JsonPayloadError},
http::header,
middleware::ErrorHandlerResponse,
};
use serde::{Deserialize, Serialize, Serializer};
use thiserror::Error;
use utoipa::IntoResponses;
use super::error_codes::{GENERIC_DB_ERROR, NOT_FOUND_ERROR};
/// `Content-Type` value applied to the JSON responses built in this module.
pub static API_RESPONSE_CONTENT_TYPE: &str = "application/json; charset=utf-8";
/// Header name for a request id.
///
/// NOTE(review): not referenced in the visible part of this file — confirm
/// where it is read and written.
pub static API_REQUEST_ID_HEADER_KEY: &str = "restrepo-request-id";
/// Marker header set to `"true"` on responses built by
/// `ApiErrorV1::error_response` and by the `ApiResponseV1` responder.
/// `to_json_error_response` forwards any response that carries it unchanged.
pub static API_RESPONSE_HEADER_KEY: &str = "restrepo-api-response";
// Error type for version 1 of the API.
//
// Every variant carries a machine-readable `code`; all but `NotFoundError` also
// carry a free-text `detail`. The `#[error]` strings define the `Display` text,
// and the `#[response]` attributes declare the HTTP status of each variant for
// the `IntoResponses` derive.
//
// Serialization is internally tagged: the variant name is written in camelCase
// under the `type` key next to the variant's fields.
//
// NOTE(review): plain `//` comments are used in this enum on purpose. utoipa's
// `IntoResponses` derive appears to pick up `///` doc comments as response
// descriptions, so doc comments here could change the generated OpenAPI
// output — confirm before converting these to `///`.
#[derive(Clone, Error, Debug, Deserialize, PartialEq, Eq, Serialize, IntoResponses)]
#[serde(tag = "type")]
#[serde(rename_all(deserialize = "camelCase", serialize = "camelCase"))]
pub enum ApiErrorV1 {
// Declared as 400 Bad Request. Both `detail` and `code` are serialized.
#[error("{code}: Malformed data: {detail}")]
#[response(status = StatusCode::BAD_REQUEST)]
MalformedDataError { detail: String, code: String },
// Declared as 404 Not Found. Carries only a `code`.
#[error("{code}: Could not find requested resource.")]
#[response(status = StatusCode::NOT_FOUND)]
NotFoundError { code: String },
// Declared as 403 Forbidden.
#[error("{code}: Operation was not permitted: {detail}")]
#[response(status = StatusCode::FORBIDDEN)]
OperationForbiddenError {
code: String,
// Left out of the serialized form; still part of the `Display` message.
#[serde(skip_serializing)]
detail: String,
},
// Declared as 401 Unauthorized.
#[error("{code}: Request could not be authorised: {detail}")]
#[response(status = StatusCode::UNAUTHORIZED)]
UnauthorizedError {
code: String,
// Left out of the serialized form; still part of the `Display` message.
#[serde(skip_serializing)]
detail: String,
},
// Declared as 500 Internal Server Error.
#[error("{code}: An unexpected error occurred: {detail}")]
#[response(status = StatusCode::INTERNAL_SERVER_ERROR)]
InternalError {
code: String,
// Left out of the serialized form; still part of the `Display` message.
#[serde(skip_serializing)]
detail: String,
},
}
impl From<StatusCode> for ApiErrorV1 {
fn from(status: StatusCode) -> Self {
match status {
StatusCode::BAD_REQUEST => ApiErrorV1::MalformedDataError {
detail: "Bad request".to_string(),
code: error_codes::DATA_FORMAT_ERROR.to_string(),
},
StatusCode::NOT_FOUND => ApiErrorV1::NotFoundError {
code: NOT_FOUND_ERROR.to_owned(),
},
StatusCode::FORBIDDEN => ApiErrorV1::OperationForbiddenError {
code: error_codes::AUTHORIZATION_ERROR.to_string(),
detail: String::default(),
},
StatusCode::UNAUTHORIZED => ApiErrorV1::UnauthorizedError {
code: error_codes::AUTHORIZATION_ERROR.to_string(),
detail: String::default(),
},
_ => ApiErrorV1::InternalError {
code: error_codes::GENERIC_ERROR.to_owned(),
detail: String::default(),
},
}
}
}
impl From<AccessForbiddenError> for ApiErrorV1 {
fn from(error: AccessForbiddenError) -> Self {
ApiErrorV1::OperationForbiddenError {
code: error_codes::ACCESS_FORBIDDEN_ERROR.to_string(),
detail: error.to_string(),
}
}
}
impl From<AuthorizationError> for ApiErrorV1 {
fn from(error: AuthorizationError) -> Self {
ApiErrorV1::UnauthorizedError {
code: error_codes::AUTHORIZATION_ERROR.to_string(),
detail: error.to_string(),
}
}
}
impl From<anyhow::Error> for ApiErrorV1 {
fn from(e: anyhow::Error) -> Self {
ApiErrorV1::InternalError {
code: error_codes::SERVICE_ERROR.to_string(),
detail: e.to_string(),
}
}
}
impl From<RepositoryError> for ApiErrorV1 {
fn from(error: RepositoryError) -> Self {
match error {
RepositoryError::DatabaseLookupError() => ApiErrorV1::NotFoundError {
code: NOT_FOUND_ERROR.to_owned(),
},
_ => ApiErrorV1::InternalError {
code: GENERIC_DB_ERROR.to_owned(),
detail: String::default(),
},
}
}
}
impl From<MappingError> for ApiErrorV1 {
fn from(error: MappingError) -> Self {
ApiErrorV1::MalformedDataError {
detail: error.to_string(),
code: error_codes::CONTOLLER_ERROR.to_string(),
}
}
}
impl From<JsonPayloadError> for ApiErrorV1 {
fn from(error: JsonPayloadError) -> Self {
ApiErrorV1::MalformedDataError {
detail: error.to_string(),
code: error_codes::DATA_FORMAT_ERROR.to_string(),
}
}
}
impl From<serde_json::error::Error> for ApiErrorV1 {
fn from(error: serde_json::error::Error) -> Self {
ApiErrorV1::MalformedDataError {
detail: error.to_string(),
code: error_codes::DATA_FORMAT_ERROR.to_string(),
}
}
}
/// Lets actix-web turn an [`ApiErrorV1`] directly into an HTTP response.
impl ResponseError for ApiErrorV1 {
    /// Returns the HTTP status associated with each variant.
    fn status_code(&self) -> StatusCode {
        match self {
            Self::MalformedDataError { .. } => StatusCode::BAD_REQUEST,
            Self::NotFoundError { .. } => StatusCode::NOT_FOUND,
            Self::OperationForbiddenError { .. } => StatusCode::FORBIDDEN,
            Self::UnauthorizedError { .. } => StatusCode::UNAUTHORIZED,
            Self::InternalError { .. } => StatusCode::INTERNAL_SERVER_ERROR,
        }
    }

    /// Builds the JSON error response.
    ///
    /// The body is the serialized error itself. The response carries the API
    /// content type and the `API_RESPONSE_HEADER_KEY` marker header; 401
    /// responses additionally get two `WWW-Authenticate` challenges.
    fn error_response(&self) -> HttpResponse {
        let status = self.status_code();
        let mut builder = HttpResponse::build(status);

        if status == StatusCode::UNAUTHORIZED {
            // One challenge per supported scheme, appended in this order.
            for challenge in ["Bearer realm=\"restrepo\"", "ApiKey realm=\"restrepo\""] {
                builder.append_header((header::WWW_AUTHENTICATE, challenge));
            }
        }

        builder
            .content_type(API_RESPONSE_CONTENT_TYPE)
            .append_header((API_RESPONSE_HEADER_KEY, "true"))
            .json(self)
    }
}
/// Pairs an HTTP status with a serializable payload.
///
/// When the struct itself is serialized (as `to_json_error_response` does),
/// `status` is written as its numeric code via `serialize_status_code`. The
/// `Responder` implementation, by contrast, sends only `message` as the body
/// and uses `status` as the response status.
#[derive(Debug, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct ApiResponseV1<D> {
/// HTTP status of the response.
#[serde(serialize_with = "serialize_status_code")]
status: StatusCode,
/// Payload to serialize.
message: D,
}
impl<D> ApiResponseV1<D>
where
    D: Serialize,
{
    /// Creates a response with the given HTTP status `code` and payload `msg`.
    pub fn new(code: StatusCode, msg: D) -> Self {
        Self {
            status: code,
            message: msg,
        }
    }
}
fn serialize_status_code<S: Serializer>(code: &StatusCode, ser: S) -> Result<S::Ok, S::Error> {
ser.serialize_u16(code.as_u16())
}
/// Lets handlers return an [`ApiResponseV1`] directly.
impl<D: Serialize> Responder for ApiResponseV1<D> {
    type Body = BoxBody;

    /// Builds the HTTP response for this value.
    ///
    /// The `API_RESPONSE_HEADER_KEY` marker header is always set. Informational
    /// (1xx), redirection (3xx) and 204 responses are sent without a body; every
    /// other status gets the API content type and `message` serialized as JSON.
    fn respond_to(self, _: &HttpRequest) -> HttpResponse<Self::Body> {
        let Self { status, message } = self;

        let mut builder = HttpResponse::build(status);
        builder.append_header((API_RESPONSE_HEADER_KEY, "true"));

        let without_body = status == StatusCode::NO_CONTENT
            || status.is_informational()
            || status.is_redirection();
        if without_body {
            return builder.finish();
        }

        builder
            .content_type(API_RESPONSE_CONTENT_TYPE)
            .json(message)
    }
}
/// Error-handler hook that replaces a response with a JSON error body.
///
/// Responses that already carry the `API_RESPONSE_HEADER_KEY` marker header are
/// forwarded unchanged. Any other response is replaced by a new one with the
/// same status, the API content type, and a serialized [`ApiResponseV1`] whose
/// message is the status' canonical reason phrase (or a fixed fallback text
/// when there is none). The original response head and body are discarded.
///
/// # Errors
///
/// Returns an internal-server-error if the JSON body cannot be serialized.
pub fn to_json_error_response<B>(res: ServiceResponse<B>) -> Result<ErrorHandlerResponse<B>> {
    let already_api_response = res.headers().get(API_RESPONSE_HEADER_KEY).is_some();
    if already_api_response {
        return Ok(ErrorHandlerResponse::Response(res.map_into_left_body()));
    }

    let status = res.status();
    // Keep only the request; the original response is dropped here.
    let (req, _) = res.into_parts();

    let reason = status
        .canonical_reason()
        .unwrap_or("No error message found");
    let payload = ApiResponseV1::new(status, reason);
    let body = serde_json::to_string(&payload).map_err(ErrorInternalServerError)?;

    let json_response = HttpResponseBuilder::new(status)
        .insert_header((header::CONTENT_TYPE, API_RESPONSE_CONTENT_TYPE))
        .body(body)
        .map_into_right_body();

    Ok(ErrorHandlerResponse::Response(ServiceResponse::new(
        req,
        json_response,
    )))
}
#[cfg(test)]
mod tests {
    use super::*;
    use actix_web::{
        App, Error,
        http::StatusCode,
        http::header,
        test,
        web::{self, Path},
    };
    use std::ops::Deref;

    /// Handler that always answers 200 with the payload `"Success"`.
    async fn success_api_responder() -> Result<ApiResponseV1<&'static str>, Error> {
        Ok(ApiResponseV1::new(StatusCode::OK, "Success"))
    }

    /// Handler that parses the path segment as `u64` and fails with a
    /// `MalformedDataError` when it is not a number.
    async fn failure_api_responder(test_data: Path<String>) -> Result<ApiResponseV1<u64>, Error> {
        let result = test_data
            .deref()
            .parse()
            .map_err(|_| ApiErrorV1::MalformedDataError {
                detail: "test data invalid".to_string(),
                code: error_codes::DATA_FORMAT_ERROR.to_string(),
            })?;
        Ok(ApiResponseV1::new(StatusCode::OK, result))
    }

    /// A successful `ApiResponseV1` yields 200, the API content type and the
    /// bare payload as the JSON body.
    #[actix_web::test]
    async fn api_response_v1() {
        let test_server =
            test::init_service(App::new().route("/", web::to(success_api_responder))).await;
        let test_request = test::TestRequest::get().uri("/").to_request();
        let test_response = test::call_service(&test_server, test_request).await;
        assert_eq!(test_response.status().as_u16(), 200);
        assert_eq!(
            test_response.headers().get(header::CONTENT_TYPE).unwrap(),
            API_RESPONSE_CONTENT_TYPE
        );
        let test_response_body: serde_json::Value = test::read_body_json(test_response).await;
        assert_eq!(test_response_body, "Success");
    }

    /// An `ApiErrorV1` returned from a handler yields the mapped status, the
    /// API content type, the marker header and the serialized error as body.
    #[actix_web::test]
    async fn api_error_response_v1() {
        let test_server =
            test::init_service(App::new().route("/{test_data}", web::to(failure_api_responder)))
                .await;
        let test_request = test::TestRequest::get().uri("/1234abcd").to_request();
        let test_response = test::call_service(&test_server, test_request).await;
        // Check the status first so a wrong status is reported with both values
        // instead of surfacing as an `unwrap` panic on a missing header.
        assert_eq!(test_response.status(), StatusCode::BAD_REQUEST);
        assert_eq!(
            test_response.headers().get(header::CONTENT_TYPE).unwrap(),
            API_RESPONSE_CONTENT_TYPE
        );
        assert!(
            test_response
                .headers()
                .get(API_RESPONSE_HEADER_KEY)
                .is_some(),
            "error response is missing the `{API_RESPONSE_HEADER_KEY}` marker header"
        );
        let test_response_body: serde_json::Value = test::read_body_json(test_response).await;
        assert_eq!(
            serde_json::json!(ApiErrorV1::MalformedDataError {
                detail: "test data invalid".to_string(),
                code: "E0002".to_string()
            }),
            test_response_body
        );
    }
}