use futures_lite::{io, AsyncWrite};
use futures_util::TryStreamExt;
use http::{
header::{HeaderMap, HeaderValue, IntoHeaderName},
status::StatusCode,
version::Version,
};
use tracing::error;
use crate::body::Body;
use crate::response::Response;
use super::encode::Encoder;
use super::glitch::Glitch;
// Crate-internal form of a response: the status, headers and version plus
// the body. This is the value that `Encoder::encode` consumes when the
// response is sent.
pin_project_lite::pin_project! {
pub(crate) struct InnerResponse {
// Status code of the response.
pub(crate) status: StatusCode,
// Header map of the response.
pub(crate) headers: HeaderMap,
// HTTP version of the response.
pub(crate) version: Version,
// Response body; marked `#[pin]` so the generated projection hands it out
// as a pinned reference.
#[pin]
pub(crate)body: Body,
}
}
impl InnerResponse {
pub(crate) fn bad_request() -> Self {
Self {
status: StatusCode::BAD_REQUEST,
headers: HeaderMap::new(),
version: Version::default(),
body: Body::empty(),
}
}
pub(crate) fn version_not_supported() -> Self {
Self {
status: StatusCode::HTTP_VERSION_NOT_SUPPORTED,
headers: HeaderMap::new(),
version: Version::default(),
body: Body::empty(),
}
}
pub(crate) fn not_implemented() -> Self {
Self {
status: StatusCode::NOT_IMPLEMENTED,
headers: HeaderMap::new(),
version: Version::default(),
body: Body::empty(),
}
}
pub(crate) async fn send<W>(self, writer: W) -> Result<ResponseWritten, std::io::Error>
where
W: AsyncWrite + Clone + Send + Sync + Unpin + 'static,
{
let mut encoder = Encoder::encode(self);
let mut writer = writer;
let bytes_written = match io::copy(&mut encoder, &mut writer).await {
Ok(b) => b,
Err(err) => {
error!("Error sending response: {}", err);
return Err(err);
}
};
Ok(ResponseWritten { bytes_written })
}
}
/// Pairs a [`Response`] that is still being built with the writer it will be
/// sent to.
///
/// The setters on this type mutate the pending response; `send` consumes the
/// value and writes the response to `writer`.
pub struct ResponseWriter<W: AsyncWrite + Clone + Send + Sync + Unpin + 'static> {
    /// Response that will be written when `send` is called.
    pub(crate) response: Response,
    /// Destination the encoded response is written to.
    pub(crate) writer: W,
}
impl<W> ResponseWriter<W>
where
W: AsyncWrite + Clone + Send + Sync + Unpin + 'static,
{
pub async fn send(self) -> Result<ResponseWritten, Glitch> {
let (parts, body) = self.response.into_parts();
let inner_resp = InnerResponse {
status: parts.status,
headers: parts.headers,
version: parts.version,
body,
};
Ok(inner_resp.send(self.writer).await?)
}
pub async fn send_code(self, code: u16) -> Result<ResponseWritten, Glitch> {
let mut this = self;
this.set_code(code);
this.send().await
}
pub fn set_status(&mut self, status: http::StatusCode) -> &mut Self {
*self.response.status_mut() = status;
self
}
pub fn set_code(&mut self, code: u16) -> &mut Self {
*self.response.status_mut() = http::StatusCode::from_u16(code).unwrap();
self
}
pub fn set_body(&mut self, body: Body) -> &mut Self {
*self.response.body_mut() = body;
self
}
pub fn append_header(
&mut self,
header_name: impl IntoHeaderName,
header_value: HeaderValue,
) -> &mut Self {
self.response
.headers_mut()
.append(header_name, header_value);
self
}
pub fn insert_header(
&mut self,
header_name: impl IntoHeaderName,
header_value: HeaderValue,
) -> &mut Self {
self.response
.headers_mut()
.insert(header_name, header_value);
self
}
pub fn response_mut(&mut self) -> &mut Response {
&mut self.response
}
pub fn response(&self) -> &Response {
&self.response
}
pub fn set_text(&mut self, text: String) -> &mut Self {
*self.response.body_mut() = text.into();
self.response
.headers_mut()
.insert(http::header::CONTENT_TYPE, "text/plain".parse().unwrap());
self
}
pub fn set_sse<S>(&mut self, stream: S)
where
S: TryStreamExt<Error = std::io::Error> + Send + Sync + Unpin + 'static,
S::Ok: AsRef<[u8]> + Send + Sync,
{
let stream = stream.into_async_read();
self.set_body(Body::from_reader(stream, None));
self.insert_header(http::header::CONTENT_TYPE, "text/event-stream".parse().unwrap());
}
}
/// Outcome of a successfully sent response.
///
/// Derives `Debug` so that it can appear in `{:?}` output and in
/// `Result::unwrap_err` / `expect_err`, which require the `Ok` type to be
/// `Debug`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct ResponseWritten {
    /// Byte count recorded when the response was sent.
    bytes_written: u64,
}

impl ResponseWritten {
    /// Returns the number of bytes recorded for the sent response.
    pub fn bytes_written(&self) -> u64 {
        self.bytes_written
    }
}