use crate::error::{ApiError, ErrorResponse};
use bytes::Bytes;
use futures_util::StreamExt;
use http::{header, HeaderMap, HeaderValue, StatusCode};
use http_body_util::Full;
use rustapi_openapi::schema::{RustApiSchema, SchemaCtx};
use rustapi_openapi::{MediaType, Operation, ResponseModifier, ResponseSpec, SchemaRef};
use serde::Serialize;
use std::collections::BTreeMap;
use std::pin::Pin;
use std::task::{Context, Poll};
/// Response body type used by [`Response`].
///
/// A body is either a single in-memory buffer or a boxed body that yields
/// `Bytes` frames and reports failures as [`ApiError`].
pub enum Body {
/// A complete payload held in memory.
Full(Full<Bytes>),
/// A pinned, boxed body producing `Bytes` data with `ApiError` as its error type.
Streaming(Pin<Box<dyn http_body::Body<Data = Bytes, Error = ApiError> + Send + 'static>>),
}
impl Body {
pub fn new(bytes: Bytes) -> Self {
Self::Full(Full::new(bytes))
}
pub fn empty() -> Self {
Self::Full(Full::new(Bytes::new()))
}
pub fn from_stream<S, E>(stream: S) -> Self
where
S: futures_util::Stream<Item = Result<Bytes, E>> + Send + 'static,
E: Into<ApiError> + 'static,
{
let body = http_body_util::StreamBody::new(
stream.map(|res| res.map_err(|e| e.into()).map(http_body::Frame::data)),
);
Self::Streaming(Box::pin(body))
}
}
impl Default for Body {
fn default() -> Self {
Self::empty()
}
}
impl http_body::Body for Body {
type Data = Bytes;
type Error = ApiError;
fn poll_frame(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Option<Result<http_body::Frame<Self::Data>, Self::Error>>> {
match self.get_mut() {
Body::Full(b) => Pin::new(b)
.poll_frame(cx)
.map_err(|_| ApiError::internal("Infallible error")),
Body::Streaming(b) => b.as_mut().poll_frame(cx),
}
}
fn is_end_stream(&self) -> bool {
match self {
Body::Full(b) => b.is_end_stream(),
Body::Streaming(b) => b.is_end_stream(),
}
}
fn size_hint(&self) -> http_body::SizeHint {
match self {
Body::Full(b) => b.size_hint(),
Body::Streaming(b) => b.size_hint(),
}
}
}
impl From<Bytes> for Body {
fn from(bytes: Bytes) -> Self {
Self::new(bytes)
}
}
impl From<String> for Body {
fn from(s: String) -> Self {
Self::new(Bytes::from(s))
}
}
impl From<&'static str> for Body {
fn from(s: &'static str) -> Self {
Self::new(Bytes::from(s))
}
}
impl From<Vec<u8>> for Body {
fn from(v: Vec<u8>) -> Self {
Self::new(Bytes::from(v))
}
}
/// An `http::Response` whose body type is [`Body`].
pub type Response = http::Response<Body>;
/// Conversion of a value into a [`Response`].
pub trait IntoResponse {
/// Consumes `self` and produces the response.
fn into_response(self) -> Response;
}
// A `Response` is already in its final form, so the conversion is the identity.
impl IntoResponse for Response {
fn into_response(self) -> Response {
self
}
}
impl IntoResponse for () {
fn into_response(self) -> Response {
http::Response::builder()
.status(StatusCode::OK)
.body(Body::empty())
.unwrap()
}
}
impl IntoResponse for &'static str {
fn into_response(self) -> Response {
http::Response::builder()
.status(StatusCode::OK)
.header(header::CONTENT_TYPE, "text/plain; charset=utf-8")
.body(Body::from(self))
.unwrap()
}
}
impl IntoResponse for String {
fn into_response(self) -> Response {
http::Response::builder()
.status(StatusCode::OK)
.header(header::CONTENT_TYPE, "text/plain; charset=utf-8")
.body(Body::from(self))
.unwrap()
}
}
impl IntoResponse for StatusCode {
fn into_response(self) -> Response {
http::Response::builder()
.status(self)
.body(Body::empty())
.unwrap()
}
}
impl<R: IntoResponse> IntoResponse for (StatusCode, R) {
    /// Converts the second element into a response, then overrides its status
    /// with the first element.
    fn into_response(self) -> Response {
        let (status, inner) = self;
        let mut response = inner.into_response();
        *response.status_mut() = status;
        response
    }
}
impl<R: IntoResponse> IntoResponse for (StatusCode, HeaderMap, R) {
    /// Converts the last element into a response, overrides its status, and
    /// extends its header map with the supplied headers.
    fn into_response(self) -> Response {
        let (status, extra_headers, inner) = self;
        let mut response = inner.into_response();
        *response.status_mut() = status;
        response.headers_mut().extend(extra_headers);
        response
    }
}
impl<T: IntoResponse, E: IntoResponse> IntoResponse for Result<T, E> {
    /// Converts whichever variant is present into a response.
    fn into_response(self) -> Response {
        self.map_or_else(E::into_response, T::into_response)
    }
}
impl IntoResponse for ApiError {
    /// Renders the error as a JSON response using the error's own status.
    ///
    /// If serializing the error payload fails, a fixed JSON document
    /// describing an internal error is sent instead.
    fn into_response(self) -> Response {
        // Read the status first: converting into `ErrorResponse` consumes `self`.
        let status = self.status;
        let payload = match serde_json::to_vec(&ErrorResponse::from(self)) {
            Ok(bytes) => bytes,
            Err(_) => {
                br#"{"error":{"type":"internal_error","message":"Failed to serialize error"}}"#
                    .to_vec()
            }
        };
        http::Response::builder()
            .status(status)
            .header(header::CONTENT_TYPE, "application/json")
            .body(Body::from(payload))
            .unwrap()
    }
}
impl ResponseModifier for ApiError {
fn update_response(op: &mut Operation) {
op.responses.insert(
"400".to_string(),
ResponseSpec {
description: "Bad Request".to_string(),
content: {
let mut map = BTreeMap::new();
map.insert(
"application/json".to_string(),
MediaType {
schema: Some(SchemaRef::Ref {
reference: "#/components/schemas/ErrorSchema".to_string(),
}),
example: None,
},
);
map
},
headers: BTreeMap::new(),
},
);
op.responses.insert(
"500".to_string(),
ResponseSpec {
description: "Internal Server Error".to_string(),
content: {
let mut map = BTreeMap::new();
map.insert(
"application/json".to_string(),
MediaType {
schema: Some(SchemaRef::Ref {
reference: "#/components/schemas/ErrorSchema".to_string(),
}),
example: None,
},
);
map
},
headers: BTreeMap::new(),
},
);
}
fn register_components(spec: &mut rustapi_openapi::OpenApiSpec) {
spec.register_in_place::<rustapi_openapi::ErrorSchema>();
spec.register_in_place::<rustapi_openapi::ErrorBodySchema>();
spec.register_in_place::<rustapi_openapi::ValidationErrorSchema>();
spec.register_in_place::<rustapi_openapi::ValidationErrorBodySchema>();
spec.register_in_place::<rustapi_openapi::FieldErrorSchema>();
}
}
/// Wrapper whose `IntoResponse` implementation serializes the inner value as
/// JSON and responds with `201 Created`.
#[derive(Debug, Clone)]
pub struct Created<T>(pub T);
impl<T: Serialize> IntoResponse for Created<T> {
fn into_response(self) -> Response {
match serde_json::to_vec(&self.0) {
Ok(body) => http::Response::builder()
.status(StatusCode::CREATED)
.header(header::CONTENT_TYPE, "application/json")
.body(Body::from(body))
.unwrap(),
Err(err) => {
ApiError::internal(format!("Failed to serialize response: {}", err)).into_response()
}
}
}
}
impl<T: RustApiSchema> ResponseModifier for Created<T> {
    /// Adds a `201` response whose `application/json` content uses `T`'s schema.
    fn update_response(op: &mut Operation) {
        let mut ctx = SchemaCtx::new();
        let media = MediaType {
            schema: Some(T::schema(&mut ctx)),
            example: None,
        };
        let spec = ResponseSpec {
            description: "Created".to_string(),
            content: BTreeMap::from([("application/json".to_string(), media)]),
            headers: BTreeMap::new(),
        };
        op.responses.insert("201".to_string(), spec);
    }

    /// Registers `T` on the spec.
    fn register_components(spec: &mut rustapi_openapi::OpenApiSpec) {
        spec.register_in_place::<T>();
    }
}
/// Marker type whose `IntoResponse` implementation yields `204 No Content`
/// with an empty body.
#[derive(Debug, Clone, Copy)]
pub struct NoContent;
impl IntoResponse for NoContent {
fn into_response(self) -> Response {
http::Response::builder()
.status(StatusCode::NO_CONTENT)
.body(Body::empty())
.unwrap()
}
}
impl ResponseModifier for NoContent {
    /// Adds a `204` response with no content and no headers.
    fn update_response(op: &mut Operation) {
        let spec = ResponseSpec {
            description: "No Content".to_string(),
            content: BTreeMap::new(),
            headers: BTreeMap::new(),
        };
        op.responses.insert("204".to_string(), spec);
    }
}
/// Wrapper whose `IntoResponse` implementation sends the inner value as a
/// `text/html; charset=utf-8` body with `200 OK`.
#[derive(Debug, Clone)]
pub struct Html<T>(pub T);
impl<T: Into<String>> IntoResponse for Html<T> {
fn into_response(self) -> Response {
http::Response::builder()
.status(StatusCode::OK)
.header(header::CONTENT_TYPE, "text/html; charset=utf-8")
.body(Body::from(self.0.into()))
.unwrap()
}
}
impl<T> ResponseModifier for Html<T> {
    /// Adds a `200` response with `text/html` content described by an inline
    /// string schema.
    fn update_response(op: &mut Operation) {
        let media = MediaType {
            schema: Some(SchemaRef::Inline(serde_json::json!({ "type": "string" }))),
            example: None,
        };
        let spec = ResponseSpec {
            description: "HTML Content".to_string(),
            content: BTreeMap::from([("text/html".to_string(), media)]),
            headers: BTreeMap::new(),
        };
        op.responses.insert("200".to_string(), spec);
    }
}
/// A redirect response: a status code paired with a `Location` header value.
#[derive(Debug, Clone)]
pub struct Redirect {
// Status code sent with the redirect.
status: StatusCode,
// Value written to the `Location` header.
location: HeaderValue,
}
impl Redirect {
pub fn to(uri: &str) -> Self {
Self {
status: StatusCode::FOUND,
location: HeaderValue::from_str(uri).expect("Invalid redirect URI"),
}
}
pub fn permanent(uri: &str) -> Self {
Self {
status: StatusCode::MOVED_PERMANENTLY,
location: HeaderValue::from_str(uri).expect("Invalid redirect URI"),
}
}
pub fn temporary(uri: &str) -> Self {
Self {
status: StatusCode::TEMPORARY_REDIRECT,
location: HeaderValue::from_str(uri).expect("Invalid redirect URI"),
}
}
}
impl IntoResponse for Redirect {
fn into_response(self) -> Response {
http::Response::builder()
.status(self.status)
.header(header::LOCATION, self.location)
.body(Body::empty())
.unwrap()
}
}
impl ResponseModifier for Redirect {
    /// Adds a generic redirection response with no content and no headers.
    ///
    /// The response is keyed as `3XX`. The OpenAPI specification only permits
    /// the uppercase wildcard `X` in status-code ranges (`1XX` through `5XX`),
    /// so a lowercase `3xx` key is not a valid Responses Object entry.
    fn update_response(op: &mut Operation) {
        let spec = ResponseSpec {
            description: "Redirection".to_string(),
            content: BTreeMap::new(),
            headers: BTreeMap::new(),
        };
        op.responses.insert("3XX".to_string(), spec);
    }
}
/// Wrapper that carries a status code as the const generic `CODE`; its
/// `IntoResponse` implementation applies that code to the inner value's response.
#[derive(Debug, Clone)]
pub struct WithStatus<T, const CODE: u16>(pub T);
impl<T: IntoResponse, const CODE: u16> IntoResponse for WithStatus<T, CODE> {
    /// Converts the inner value into a response and overrides its status with `CODE`.
    ///
    /// If `CODE` is rejected by `StatusCode::from_u16`, the inner response's
    /// status is left unchanged.
    fn into_response(self) -> Response {
        let mut response = self.0.into_response();
        match StatusCode::from_u16(CODE) {
            Ok(status) => *response.status_mut() = status,
            // Deliberately best-effort: keep whatever status the inner value produced.
            Err(_) => {}
        }
        response
    }
}
impl<T: RustApiSchema, const CODE: u16> ResponseModifier for WithStatus<T, CODE> {
    /// Adds a response keyed by `CODE` whose `application/json` content uses
    /// `T`'s schema.
    fn update_response(op: &mut Operation) {
        let mut ctx = SchemaCtx::new();
        let media = MediaType {
            schema: Some(T::schema(&mut ctx)),
            example: None,
        };
        let spec = ResponseSpec {
            description: format!("Response with status {}", CODE),
            content: BTreeMap::from([("application/json".to_string(), media)]),
            headers: BTreeMap::new(),
        };
        op.responses.insert(CODE.to_string(), spec);
    }

    /// Registers `T` on the spec.
    fn register_components(spec: &mut rustapi_openapi::OpenApiSpec) {
        spec.register_in_place::<T>();
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use proptest::prelude::*;

    /// Evaluates to the numeric status produced by `WithStatus::<_, CODE>(payload)`.
    macro_rules! status_of {
        ($code:tt, $payload:expr) => {
            WithStatus::<_, { $code }>($payload)
                .into_response()
                .status()
                .as_u16()
        };
    }

    /// Collects the whole body into a single `Bytes` buffer.
    async fn body_to_bytes(body: Body) -> Bytes {
        use http_body_util::BodyExt;
        let collected = body.collect().await.unwrap();
        collected.to_bytes()
    }

    proptest! {
        #![proptest_config(ProptestConfig::with_cases(100))]

        #[test]
        fn prop_with_status_response_correctness(
            body in "[a-zA-Z0-9 ]{0,100}",
        ) {
            let rt = tokio::runtime::Runtime::new().unwrap();
            rt.block_on(async {
                // The status must always equal the const generic code.
                prop_assert_eq!(status_of!(200, body.clone()), 200);
                prop_assert_eq!(status_of!(201, body.clone()), 201);
                prop_assert_eq!(status_of!(202, body.clone()), 202);
                prop_assert_eq!(status_of!(204, body.clone()), 204);
                prop_assert_eq!(status_of!(400, body.clone()), 400);
                prop_assert_eq!(status_of!(404, body.clone()), 404);
                prop_assert_eq!(status_of!(418, body.clone()), 418);
                prop_assert_eq!(status_of!(500, body.clone()), 500);
                prop_assert_eq!(status_of!(503, body.clone()), 503);

                // The payload must pass through the wrapper untouched.
                let response: Response = WithStatus::<_, 200>(body.clone()).into_response();
                let raw = body_to_bytes(response.into_body()).await;
                let text = String::from_utf8_lossy(&raw);
                prop_assert_eq!(text.as_ref(), body.as_str());
                Ok(())
            })?;
        }
    }

    #[tokio::test]
    async fn test_with_status_preserves_content_type() {
        let response: Response = WithStatus::<_, 202>("hello world").into_response();
        assert_eq!(response.status().as_u16(), 202);

        let content_type = response.headers().get(header::CONTENT_TYPE).unwrap();
        assert_eq!(content_type, "text/plain; charset=utf-8");
    }

    #[tokio::test]
    async fn test_with_status_with_empty_body() {
        let response: Response = WithStatus::<_, 204>(()).into_response();
        assert_eq!(response.status().as_u16(), 204);

        let raw = body_to_bytes(response.into_body()).await;
        assert!(raw.is_empty());
    }

    #[test]
    fn test_with_status_common_codes() {
        // Const generics need literal codes, so the list is expanded by a macro.
        macro_rules! assert_codes {
            ($($code:tt),+ $(,)?) => {
                $( assert_eq!(status_of!($code, ""), $code); )+
            };
        }

        assert_codes!(100, 200, 201, 202, 204, 301, 302, 400, 401, 403, 404, 500, 502, 503);
    }
}