use core::fmt;
use crate::{
io::{Read, Write},
sync::oneshot_broadcast,
KeepAlive, ResponseSent,
};
pub use picoserve_derive::ErrorWithStatusCode;
pub mod chunked;
pub mod custom;
pub mod fs;
pub mod sse;
pub mod status;
pub mod with_state;
#[cfg(feature = "json")]
pub mod json;
#[cfg(feature = "ws")]
pub mod ws;
pub(crate) mod response_stream;
pub(crate) use response_stream::ResponseStream;
pub use fs::{Directory, File};
pub use sse::EventStream;
pub use status::StatusCode;
pub use with_state::{ContentUsingState, IntoResponseWithState};
#[cfg(feature = "json")]
pub use json::Json;
#[cfg(feature = "ws")]
pub use ws::WebSocketUpgrade;
/// Returns `value` unchanged.
///
/// Its only effect is to make the compiler check, at the call site, that `T` implements
/// [`IntoResponse`].
#[cfg(feature = "ws")]
fn assert_implements_into_response<T: IntoResponse>(value: T) -> T {
    value
}
/// Returns `value` unchanged.
///
/// Its only effect is to make the compiler check, at the call site, that `T` implements
/// [`IntoResponseWithState<State>`].
#[cfg(feature = "ws")]
fn assert_implements_into_response_with_state<State, T: IntoResponseWithState<State>>(
    value: T,
) -> T {
    value
}
/// A [`fmt::Write`] sink that discards text and only counts it.
///
/// Every fragment written adds its UTF-8 byte length to the borrowed counter. After
/// formatting into this sink, the counter therefore holds the byte length of the formatted
/// output.
struct MeasureFormatSize<'a>(&'a mut usize);

impl fmt::Write for MeasureFormatSize<'_> {
    fn write_str(&mut self, fragment: &str) -> fmt::Result {
        let Self(total) = self;
        **total += fragment.len();
        Ok(())
    }
}
/// Where [`AfterBodyReader`] currently obtains the bytes that follow the request body.
pub(crate) enum AfterBodyReadMode<'r> {
    /// Read directly from the underlying reader.
    ReadFromReader,
    /// Hand out the bytes in `remaining` first. Once the slice is empty, switch to
    /// `ReadFromReader`.
    ReadFromBuffer {
        /// Buffered bytes not yet handed to a caller.
        remaining: &'r [u8],
    },
    /// Read and discard `body_bytes_remaining` bytes from the reader, using
    /// `scratch_buffer` as a throwaway destination. Then switch to `ReadFromReader`.
    SkipRemainingBodyFromReader {
        /// Scratch space that skipped bytes are read into and then ignored.
        scratch_buffer: &'r mut [u8],
        /// Number of body bytes still to be discarded.
        body_bytes_remaining: usize,
    },
}
/// Wraps a reader and yields the bytes that come after the request body, as selected by
/// `mode`.
pub(crate) struct AfterBodyReader<'r, R: Read> {
    /// Current source of bytes. It is switched to `ReadFromReader` once buffered data is
    /// drained or the remaining body has been skipped.
    pub(crate) mode: AfterBodyReadMode<'r>,
    /// The underlying reader.
    pub(crate) reader: R,
}
impl<R: Read> AfterBodyReader<'_, R> {
    /// Reads bytes that follow the request body into `buffer` and returns how many were
    /// written.
    ///
    /// Behaviour depends on `self.mode`:
    /// - `ReadFromReader`: forwards the call to the underlying reader.
    /// - `ReadFromBuffer`: copies as many buffered bytes as fit into `buffer`. Once the
    ///   buffered slice is empty, switches to `ReadFromReader` and retries.
    /// - `SkipRemainingBodyFromReader`: reads the remaining body bytes into the scratch
    ///   buffer and discards them, then switches to `ReadFromReader` and retries. If the
    ///   reader returns 0 bytes while skipping, switches to `ReadFromReader` and returns
    ///   `Ok(0)`.
    ///
    /// `_upgrade_token` is not used in the body. Presumably, requiring it restricts this
    /// method to holders of an `UpgradeToken` — confirm.
    ///
    /// # Errors
    ///
    /// Returns any error produced by the underlying reader.
    async fn read_after_body(
        &mut self,
        buffer: &mut [u8],
        _upgrade_token: &crate::extract::UpgradeToken,
    ) -> Result<usize, R::Error> {
        // `continue` re-dispatches on the (possibly changed) mode.
        // `break` leaves the loop with the value of the selected match arm.
        loop {
            break match &mut self.mode {
                AfterBodyReadMode::ReadFromReader => self.reader.read(buffer).await,
                AfterBodyReadMode::ReadFromBuffer { remaining } => {
                    if remaining.is_empty() {
                        self.mode = AfterBodyReadMode::ReadFromReader;
                        continue;
                    }
                    // Hand out at most `buffer.len()` bytes and advance past them.
                    let read_size = remaining.len().min(buffer.len());
                    buffer[..read_size].copy_from_slice(&remaining[..read_size]);
                    *remaining = &remaining[read_size..];
                    Ok(read_size)
                }
                AfterBodyReadMode::SkipRemainingBodyFromReader {
                    scratch_buffer,
                    body_bytes_remaining,
                } => {
                    if *body_bytes_remaining == 0 {
                        self.mode = AfterBodyReadMode::ReadFromReader;
                        continue;
                    }
                    // Request no more than the remaining body bytes, so bytes after the body
                    // are not consumed here.
                    // NOTE(review): if `scratch_buffer` is empty, this is a zero-length read,
                    // and its result of 0 is handled below exactly like end of stream.
                    // Confirm the scratch buffer is never empty.
                    let read_buffer_size = (*body_bytes_remaining).min(scratch_buffer.len());
                    let read_size = self
                        .reader
                        .read(&mut scratch_buffer[..read_buffer_size])
                        .await?;
                    // Cannot underflow as long as the reader returns no more bytes than it
                    // was given room for (`read_buffer_size <= *body_bytes_remaining`).
                    *body_bytes_remaining -= read_size;
                    if read_size == 0 {
                        // A zero-length read: conventionally end of stream before the body
                        // was fully skipped.
                        self.mode = AfterBodyReadMode::ReadFromReader;
                        Ok(0)
                    } else {
                        // Skipped bytes are discarded; keep going until the body is consumed.
                        continue;
                    }
                }
            };
        }
    }
}
/// A connection returned by [`Connection::upgrade`].
///
/// Reads are delegated to `AfterBodyReader::read_after_body`, so they yield the bytes that
/// follow the request body.
pub struct UpgradedConnection<'r, R: Read> {
    upgrade_token: crate::extract::UpgradeToken,
    reader: AfterBodyReader<'r, R>,
}

impl<R: Read> crate::io::ErrorType for UpgradedConnection<'_, R> {
    type Error = R::Error;
}

impl<R: Read> Read for UpgradedConnection<'_, R> {
    async fn read(&mut self, buffer: &mut [u8]) -> Result<usize, Self::Error> {
        let Self {
            upgrade_token,
            reader,
        } = self;
        reader.read_after_body(buffer, upgrade_token).await
    }
}
/// The client connection handed to a response [`Body`] while it is being written.
///
/// Holds the reader for bytes following the request body, the connection's flags, and a
/// listener for the shutdown signal.
pub struct Connection<'r, R: Read> {
    pub(crate) reader: AfterBodyReader<'r, R>,
    pub(crate) connection_flags: &'r mut crate::request::ConnectionFlags,
    pub(crate) shutdown_signal: oneshot_broadcast::Listener<'r, ()>,
}

impl<'r, R: Read> Connection<'r, R> {
    /// Marks the connection as upgraded and returns an [`UpgradedConnection`] that takes
    /// over the reader.
    pub fn upgrade(
        self,
        upgrade_token: crate::extract::UpgradeToken,
    ) -> UpgradedConnection<'r, R> {
        let Self {
            reader,
            connection_flags,
            ..
        } = self;

        connection_flags.notify_connection_has_been_upgraded();

        UpgradedConnection {
            upgrade_token,
            reader,
        }
    }

    /// Resolves once `UpgradeToken::discard_all_data` finishes with this connection.
    ///
    /// Judging by its name, that helper reads and discards incoming data until the client
    /// disconnects.
    ///
    /// # Errors
    ///
    /// Returns any read error reported while discarding.
    pub async fn wait_for_disconnection(self) -> Result<(), R::Error> {
        crate::extract::UpgradeToken::discard_all_data(self).await
    }

    /// Runs `action` alongside [`Self::wait_for_disconnection`].
    ///
    /// If the disconnection side finishes first, the result is `Ok(default)`. This assumes
    /// `crate::futures::select` completes with whichever future finishes first.
    ///
    /// # Errors
    ///
    /// Returns the error from `action`, or a read error hit while waiting for disconnection.
    pub async fn run_until_disconnection<T>(
        self,
        default: T,
        action: impl core::future::Future<Output = Result<T, R::Error>>,
    ) -> Result<T, R::Error> {
        let disconnected = async {
            self.wait_for_disconnection().await?;
            Ok::<T, R::Error>(default)
        };

        crate::futures::select(action, disconnected).await
    }
}
/// A reader that is permanently at end of stream: every read returns `Ok(0)`.
///
/// `E` only selects the reader's error type; no error is ever produced.
pub(crate) struct EmptyReader<E: crate::io::Error>(core::marker::PhantomData<E>);

impl<E: crate::io::Error> crate::io::ErrorType for EmptyReader<E> {
    type Error = E;
}

impl<E: crate::io::Error> crate::io::Read for EmptyReader<E> {
    async fn read(&mut self, _buffer: &mut [u8]) -> Result<usize, Self::Error> {
        const END_OF_STREAM: usize = 0;
        Ok(END_OF_STREAM)
    }
}
/// Owned storage that `Connection::empty` borrows from.
///
/// Holds a fresh set of connection flags and a shutdown signal core.
pub(crate) struct EmptyParts {
    connection_flags: crate::request::ConnectionFlags,
    shutdown_signal: oneshot_broadcast::SignalCore<()>,
}

impl Default for EmptyParts {
    fn default() -> Self {
        let connection_flags = crate::request::ConnectionFlags::new();
        let shutdown_signal = oneshot_broadcast::Signal::core();

        Self {
            connection_flags,
            shutdown_signal,
        }
    }
}
impl<'r, E: crate::io::Error> Connection<'r, EmptyReader<E>> {
    /// Builds a connection with no incoming data.
    ///
    /// The connection flags and a shutdown-signal listener are borrowed from `parts`. Every
    /// read returns end of stream, because the reader is an [`EmptyReader`].
    pub(crate) fn empty(parts: &'r mut EmptyParts) -> Self {
        let EmptyParts {
            connection_flags,
            shutdown_signal,
        } = parts;

        let reader = AfterBodyReader {
            mode: AfterBodyReadMode::ReadFromReader,
            reader: EmptyReader(core::marker::PhantomData),
        };

        Self {
            reader,
            connection_flags,
            shutdown_signal: shutdown_signal.make_signal().listen(),
        }
    }
}
/// Visitor that receives the headers produced by a [`HeadersIter`].
///
/// Implementations in this module call `call` once per header and then call `finalize`
/// once to produce the result.
#[doc(hidden)]
pub trait ForEachHeader {
    /// Value produced by `finalize`.
    type Output;
    /// Error type returned by `call` and `finalize`.
    type Error;

    /// Handles one header, given its name and a displayable value.
    async fn call<Value: fmt::Display>(
        &mut self,
        name: &str,
        value: Value,
    ) -> Result<(), Self::Error>;

    /// Completes the visit, consuming the visitor.
    async fn finalize(self) -> Result<Self::Output, Self::Error>;
}
/// Adapter that forwards headers to a borrowed visitor without consuming it.
///
/// Its `finalize` is a no-op. This lets a composite header collection hand the same visitor
/// to several parts and finalize the real visitor itself, exactly once.
struct BorrowedForEachHeader<'a, F: ForEachHeader>(&'a mut F);

impl<F: ForEachHeader> ForEachHeader for BorrowedForEachHeader<'_, F> {
    type Output = ();
    type Error = F::Error;

    async fn call<Value: fmt::Display>(
        &mut self,
        name: &str,
        value: Value,
    ) -> Result<(), F::Error> {
        let Self(inner) = self;
        inner.call(name, value).await
    }

    async fn finalize(self) -> Result<Self::Output, Self::Error> {
        // Deliberately empty: whoever owns the wrapped visitor finalizes it.
        Ok(())
    }
}
/// A collection of response headers that can be visited one header at a time.
pub trait HeadersIter {
    /// Visits every header with `f` and returns the visitor's finalized output.
    ///
    /// Implementations in this module pass each header to `f.call`, stop at the first
    /// error, and finish with `f.finalize()`.
    async fn for_each_header<F: ForEachHeader>(self, f: F) -> Result<F::Output, F::Error>;
}
/// A single `(name, value)` pair is one header.
impl<V: fmt::Display> HeadersIter for (&str, V) {
    async fn for_each_header<F: ForEachHeader>(self, mut f: F) -> Result<F::Output, F::Error> {
        f.call(self.0, self.1).await?;
        f.finalize().await
    }
}
/// A slice of `(name, value)` pairs yields one header per pair, in slice order.
impl<V: fmt::Display> HeadersIter for &[(&str, V)] {
    async fn for_each_header<F: ForEachHeader>(self, mut f: F) -> Result<F::Output, F::Error> {
        for header in self {
            let (name, value) = header;
            f.call(*name, value).await?;
        }
        f.finalize().await
    }
}
/// An array of header collections yields the headers of each element, in array order.
impl<H: HeadersIter, const N: usize> HeadersIter for [H; N] {
    async fn for_each_header<F: ForEachHeader>(self, mut f: F) -> Result<F::Output, F::Error> {
        for group in self {
            let borrowed = BorrowedForEachHeader(&mut f);
            group.for_each_header(borrowed).await?;
        }
        f.finalize().await
    }
}
/// `Some(headers)` yields `headers`; `None` yields no headers.
impl<T: HeadersIter> HeadersIter for Option<T> {
    async fn for_each_header<F: ForEachHeader>(self, f: F) -> Result<F::Output, F::Error> {
        match self {
            Some(headers) => headers.for_each_header(f).await,
            None => f.finalize().await,
        }
    }
}
struct HeadersChain<A: HeadersIter, B: HeadersIter>(A, B);
impl<A: HeadersIter, B: HeadersIter> HeadersIter for HeadersChain<A, B> {
async fn for_each_header<F: ForEachHeader>(self, mut f: F) -> Result<F::Output, F::Error> {
let Self(a, b) = self;
a.for_each_header(BorrowedForEachHeader(&mut f)).await?;
b.for_each_header(BorrowedForEachHeader(&mut f)).await?;
f.finalize().await
}
}
/// The body of a [`Response`].
pub trait Body {
    /// Writes the body to `writer`.
    ///
    /// `connection` is the client connection (see [`Connection`]). Its reader shares the
    /// writer's error type.
    async fn write_response_body<R: Read, W: Write<Error = R::Error>>(
        self,
        connection: Connection<'_, R>,
        writer: W,
    ) -> Result<(), W::Error>;
}
/// A body with nothing in it; writing it does nothing and always succeeds.
#[doc(hidden)]
pub struct NoBody;

impl Body for NoBody {
    async fn write_response_body<R: Read, W: Write<Error = R::Error>>(
        self,
        _conn: Connection<'_, R>,
        _out: W,
    ) -> Result<(), W::Error> {
        // Nothing to write.
        Result::Ok(())
    }
}
/// Placeholder for the body slot of tuple responses such as `(StatusCode, NoContent)`.
///
/// Such tuples are sent via `Response::empty`, i.e. with no body.
pub struct NoContent;
/// A response payload with a known content type and length.
pub trait Content {
    /// Value that `Response::new` sends as the `Content-Type` header.
    fn content_type(&self) -> &'static str;
    /// Value that `Response::new` sends as the `Content-Length` header.
    ///
    /// It should equal the number of bytes `write_content` writes.
    fn content_length(&self) -> usize;
    /// Writes the payload to `writer`.
    async fn write_content<W: Write>(self, writer: W) -> Result<(), W::Error>;
}
// Expands to the three `Content` methods, each delegating to `self.$as()`.
// Used for types that can be viewed as another `Content` type (e.g. `as_slice`, `as_str`).
macro_rules! content_methods {
    ($as:ident) => {
        fn content_type(&self) -> &'static str {
            self.$as().content_type()
        }
        fn content_length(&self) -> usize {
            self.$as().content_length()
        }
        async fn write_content<W: Write>(self, writer: W) -> Result<(), W::Error> {
            self.$as().write_content(writer).await
        }
    };
}
/// A [`Body`] that writes a single [`Content`] value and then flushes the writer.
#[doc(hidden)]
pub struct ContentBody<C: Content> {
    content: C,
}

impl<C: Content> Body for ContentBody<C> {
    async fn write_response_body<R: Read, W: Write<Error = R::Error>>(
        self,
        _connection: Connection<'_, R>,
        mut writer: W,
    ) -> Result<(), W::Error> {
        let Self { content } = self;
        content.write_content(&mut writer).await?;
        writer.flush().await
    }
}
/// Raw bytes, sent as `application/octet-stream`.
impl Content for &[u8] {
    fn content_type(&self) -> &'static str {
        "application/octet-stream"
    }

    fn content_length(&self) -> usize {
        let bytes: &[u8] = self;
        bytes.len()
    }

    async fn write_content<W: Write>(self, mut writer: W) -> Result<(), W::Error> {
        let bytes: &[u8] = self;
        writer.write_all(bytes).await
    }
}
// Sent exactly like the equivalent byte slice.
impl<const N: usize> Content for heapless::Vec<u8, N> {
    content_methods!(as_slice);
}
// Sent exactly like the equivalent byte slice.
#[cfg(any(test, feature = "alloc"))]
impl Content for alloc::vec::Vec<u8> {
    content_methods!(as_slice);
}
/// UTF-8 text, sent as `text/plain; charset=utf-8`.
impl Content for &str {
    fn content_type(&self) -> &'static str {
        "text/plain; charset=utf-8"
    }

    fn content_length(&self) -> usize {
        // Length in bytes, not in characters.
        self.as_bytes().len()
    }

    async fn write_content<W: Write>(self, writer: W) -> Result<(), W::Error> {
        let bytes: &[u8] = self.as_bytes();
        bytes.write_content(writer).await
    }
}
// Sent exactly like the equivalent `&str`.
impl<const N: usize> Content for heapless::String<N> {
    content_methods!(as_str);
}
// Sent exactly like the equivalent `&str`.
#[cfg(any(test, feature = "alloc"))]
impl Content for alloc::string::String {
    content_methods!(as_str);
}
/// Lazily formatted text, sent with the same content type as `&str`.
///
/// The length is measured by formatting once into a byte counter. The text is formatted
/// again when the body is written.
impl Content for fmt::Arguments<'_> {
    fn content_type(&self) -> &'static str {
        "".content_type()
    }

    fn content_length(&self) -> usize {
        use fmt::Write;

        let mut size = 0;
        let outcome = write!(MeasureFormatSize(&mut size), "{self}");
        match outcome {
            Ok(()) => size,
            // A formatting failure is reported as a zero length.
            Err(_) => 0,
        }
    }

    async fn write_content<W: Write>(self, mut writer: W) -> Result<(), W::Error> {
        use crate::io::WriteExt;
        write!(writer, "{}", self).await
    }
}
#[doc(hidden)]
pub struct NoHeaders;
impl HeadersIter for NoHeaders {
async fn for_each_header<F: ForEachHeader>(self, f: F) -> Result<F::Output, F::Error> {
f.finalize().await
}
}
/// The `Content-Type` and `Content-Length` headers taken from a [`Content`] value.
#[doc(hidden)]
pub struct ContentHeaders {
    content_type: &'static str,
    content_length: usize,
}

impl HeadersIter for ContentHeaders {
    async fn for_each_header<F: ForEachHeader>(self, mut f: F) -> Result<F::Output, F::Error> {
        let Self {
            content_type,
            content_length,
        } = self;

        f.call("Content-Type", content_type).await?;
        f.call("Content-Length", content_length).await?;
        f.finalize().await
    }
}
/// An HTTP response made of a status code, a header collection and a body.
pub struct Response<H: HeadersIter, B: Body> {
    /// Status code sent with the response.
    pub(crate) status_code: StatusCode,
    /// Headers sent with the response.
    pub(crate) headers: H,
    /// Body sent with the response.
    pub(crate) body: B,
}
impl<C: Content> Response<ContentHeaders, ContentBody<C>> {
    /// Creates a response with `status_code` whose body is `content`.
    ///
    /// The `Content-Type` and `Content-Length` headers are taken from `content`.
    pub fn new(status_code: StatusCode, content: C) -> Self {
        let headers = ContentHeaders {
            content_type: content.content_type(),
            content_length: content.content_length(),
        };
        let body = ContentBody { content };

        Self {
            status_code,
            headers,
            body,
        }
    }

    /// Creates a `200 OK` response whose body is `body`.
    pub fn ok(body: C) -> Self {
        let status_code = StatusCode::OK;
        Self::new(status_code, body)
    }
}
impl Response<NoHeaders, NoBody> {
pub fn empty(status_code: StatusCode) -> Self {
Self {
status_code,
headers: NoHeaders,
body: NoBody,
}
}
}
impl<H: HeadersIter, B: Body> Response<H, B> {
    /// Returns the response's status code.
    #[must_use]
    pub fn status_code(&self) -> StatusCode {
        self.status_code
    }

    /// Returns this response with its status code replaced by `status_code`.
    ///
    /// Headers and body are kept unchanged.
    #[must_use = "this consumes the response and returns a modified one"]
    pub fn with_status_code(self, status_code: StatusCode) -> Self {
        Self { status_code, ..self }
    }

    /// Returns this response with `headers` appended after its existing headers.
    #[must_use = "this consumes the response and returns a modified one"]
    pub fn with_headers<HH: HeadersIter>(self, headers: HH) -> Response<impl HeadersIter, B> {
        let Response {
            status_code,
            headers: current_headers,
            body,
        } = self;

        Response {
            status_code,
            // Existing headers are visited first, then the new ones.
            headers: HeadersChain(current_headers, headers),
            body,
        }
    }

    /// Returns this response with a single header `name: value` appended.
    #[must_use = "this consumes the response and returns a modified one"]
    pub fn with_header<V: fmt::Display>(
        self,
        name: &'static str,
        value: V,
    ) -> Response<impl HeadersIter, B> {
        self.with_headers((name, value))
    }
}
/// Writes a complete [`Response`] to the client.
pub trait ResponseWriter: Sized {
    /// Error produced while writing.
    type Error;

    /// Writes `response` and returns a [`ResponseSent`] token on success.
    ///
    /// `connection` is the client connection; its reader shares this writer's error type.
    async fn write_response<R: Read<Error = Self::Error>, H: HeadersIter, B: Body>(
        self,
        connection: Connection<'_, R>,
        response: Response<H, B>,
    ) -> Result<ResponseSent, Self::Error>;
}
/// Values that can be written to the client as an HTTP response.
pub trait IntoResponse: Sized {
    /// Writes `self` as a response using `response_writer`.
    ///
    /// `connection` is the client connection.
    async fn write_to<R: Read, W: ResponseWriter<Error = R::Error>>(
        self,
        connection: Connection<'_, R>,
        response_writer: W,
    ) -> Result<ResponseSent, W::Error>;
}
/// Any [`Content`] is sent as a `200 OK` response with that content as the body.
impl<C: Content> IntoResponse for C {
    async fn write_to<R: Read, W: ResponseWriter<Error = R::Error>>(
        self,
        connection: Connection<'_, R>,
        response_writer: W,
    ) -> Result<ResponseSent, W::Error> {
        let response = Response::ok(self);
        response_writer.write_response(connection, response).await
    }
}
/// A [`Response`] is written exactly as built.
impl<H: HeadersIter, B: Body> IntoResponse for Response<H, B> {
    async fn write_to<R: Read, W: ResponseWriter<Error = R::Error>>(
        self,
        connection: Connection<'_, R>,
        response_writer: W,
    ) -> Result<ResponseSent, W::Error> {
        W::write_response(response_writer, connection, self).await
    }
}
/// `Infallible` has no values, so this can never actually run.
///
/// The empty `match` lets the compiler accept any return type.
impl IntoResponse for core::convert::Infallible {
    async fn write_to<R: Read, W: ResponseWriter<Error = R::Error>>(
        self,
        _connection: Connection<'_, R>,
        _response_writer: W,
    ) -> Result<ResponseSent, W::Error> {
        match self {}
    }
}
impl IntoResponse for () {
async fn write_to<R: Read, W: ResponseWriter<Error = R::Error>>(
self,
connection: Connection<'_, R>,
response_writer: W,
) -> Result<ResponseSent, W::Error> {
"OK\n".write_to(connection, response_writer).await
}
}
/// `Ok` and `Err` are each written using their own [`IntoResponse`] implementation.
impl<T: IntoResponse, E: IntoResponse> IntoResponse for Result<T, E> {
    async fn write_to<R: Read, W: ResponseWriter<Error = R::Error>>(
        self,
        connection: Connection<'_, R>,
        response_writer: W,
    ) -> Result<ResponseSent, W::Error> {
        let success = match self {
            Ok(success) => success,
            Err(failure) => return failure.write_to(connection, response_writer).await,
        };
        success.write_to(connection, response_writer).await
    }
}
// Implements `IntoResponse` for tuples of header collections plus a body.
// Each `;`-terminated argument line lists the header type parameters for one arity.
// For every arity it generates:
// - `(StatusCode, H1, ..., Hn, C)`: `Response::new(status, body)` plus the headers.
// - `(H1, ..., Hn, C)`: the same, with `StatusCode::OK`.
// - `(StatusCode, H1, ..., Hn, NoContent)`: `Response::empty(status)` plus the headers.
// Extra headers are appended in tuple order via `with_headers`.
macro_rules! declare_tuple_into_response {
    ($($($name:ident)*;)*) => {
        $(
            // (StatusCode, headers..., content)
            impl<$($name: HeadersIter,)* C: Content> IntoResponse for (StatusCode, $($name,)* C,) {
                // The header type parameters (`H1`, ...) double as local binding names.
                #[allow(non_snake_case)]
                async fn write_to<R: Read, W: ResponseWriter<Error = R::Error>>(self, connection: Connection<'_, R>, response_writer: W) -> Result<ResponseSent, W::Error> {
                    let (status_code, $($name,)* body) = self;
                    response_writer.write_response(
                        connection,
                        Response::new(status_code, body)
                        $(.with_headers($name,))*
                    ).await
                }
            }
            // (headers..., content) with an implicit `200 OK`
            impl<$($name: HeadersIter,)* C: Content> IntoResponse for ($($name,)* C,) {
                // The header type parameters (`H1`, ...) double as local binding names.
                #[allow(non_snake_case)]
                async fn write_to<R: Read, W: ResponseWriter<Error = R::Error>>(self, connection: Connection<'_, R>, response_writer: W) -> Result<ResponseSent, W::Error> {
                    let ($($name,)* body,) = self;
                    response_writer.write_response(
                        connection,
                        Response::new(StatusCode::OK, body)
                        $(.with_headers($name,))*
                    ).await
                }
            }
            // (StatusCode, headers..., NoContent): no body
            impl<$($name: HeadersIter,)*> IntoResponse for (StatusCode, $($name,)* NoContent,) {
                // The header type parameters (`H1`, ...) double as local binding names.
                #[allow(non_snake_case)]
                async fn write_to<R: Read, W: ResponseWriter<Error = R::Error>>(self, connection: Connection<'_, R>, response_writer: W) -> Result<ResponseSent, W::Error> {
                    let (status_code, $($name,)* NoContent,) = self;
                    response_writer.write_response(
                        connection,
                        Response::empty(status_code)
                        $(.with_headers($name,))*
                    ).await
                }
            }
        )*
    };
}
// Tuple responses with 0 through 16 header collections.
declare_tuple_into_response!(
    ;
    H1;
    H1 H2;
    H1 H2 H3;
    H1 H2 H3 H4;
    H1 H2 H3 H4 H5;
    H1 H2 H3 H4 H5 H6;
    H1 H2 H3 H4 H5 H6 H7;
    H1 H2 H3 H4 H5 H6 H7 H8;
    H1 H2 H3 H4 H5 H6 H7 H8 H9;
    H1 H2 H3 H4 H5 H6 H7 H8 H9 H10;
    H1 H2 H3 H4 H5 H6 H7 H8 H9 H10 H11;
    H1 H2 H3 H4 H5 H6 H7 H8 H9 H10 H11 H12;
    H1 H2 H3 H4 H5 H6 H7 H8 H9 H10 H11 H12 H13;
    H1 H2 H3 H4 H5 H6 H7 H8 H9 H10 H11 H12 H13 H14;
    H1 H2 H3 H4 H5 H6 H7 H8 H9 H10 H11 H12 H13 H14 H15;
    H1 H2 H3 H4 H5 H6 H7 H8 H9 H10 H11 H12 H13 H14 H15 H16;
);
/// Responds with `200 OK` and a plain-text body.
///
/// The body is the wrapped value's `Debug` representation followed by `"\r\n"`.
pub struct DebugValue<D>(pub D);

impl<D: fmt::Debug> IntoResponse for DebugValue<D> {
    async fn write_to<R: Read, W: ResponseWriter<Error = R::Error>>(
        self,
        connection: Connection<'_, R>,
        response_writer: W,
    ) -> Result<ResponseSent, W::Error> {
        let Self(value) = self;
        response_writer
            .write_response(connection, Response::ok(format_args!("{value:?}\r\n")))
            .await
    }
}
/// A redirect response.
///
/// It sends the status code, a `Location` header, and the target location followed by a
/// newline as a plain-text body.
pub struct Redirect {
    status_code: StatusCode,
    location: &'static str,
}

impl Redirect {
    /// Redirects to `location` with `StatusCode::SEE_OTHER`.
    pub fn to(location: &'static str) -> Self {
        let status_code = StatusCode::SEE_OTHER;
        Self {
            status_code,
            location,
        }
    }
}

impl IntoResponse for Redirect {
    async fn write_to<R: Read, W: ResponseWriter<Error = R::Error>>(
        self,
        connection: Connection<'_, R>,
        response_writer: W,
    ) -> Result<ResponseSent, W::Error> {
        let Self {
            status_code,
            location,
        } = self;

        (
            status_code,
            ("Location", location),
            format_args!("{location}\n"),
        )
            .write_to(connection, response_writer)
            .await
    }
}
/// An error that can be displayed, can be written as a response, and reports an HTTP
/// status code.
///
/// A same-named item is re-exported from `picoserve_derive` at the top of this module;
/// presumably it is a derive macro for this trait — confirm.
pub trait ErrorWithStatusCode: fmt::Display + IntoResponse {
    /// The HTTP status code associated with this error.
    fn status_code(&self) -> StatusCode;
}