use std::any::{Any, TypeId};
use std::collections::HashMap;
use std::fmt;
use std::future::Future;
use std::pin::Pin;
use std::sync::Mutex;
pub use fastapi_types::Method;
use asupersync::stream::Stream;
/// Failure that can occur while pulling chunks from a streaming request body.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum RequestBodyStreamError {
    /// More bytes arrived than the configured limit allows.
    TooLarge {
        /// Number of bytes received when the limit was hit.
        received: usize,
        /// The configured maximum.
        max: usize,
    },
    /// The peer closed the connection before the body was complete.
    ConnectionClosed,
    /// An underlying I/O failure, carried as its rendered message.
    Io(String),
}

impl fmt::Display for RequestBodyStreamError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::ConnectionClosed => f.write_str("connection closed while reading request body"),
            Self::Io(detail) => write!(f, "I/O error while reading request body: {detail}"),
            Self::TooLarge { received, max } => write!(
                f,
                "request body too large: received {received} bytes (max {max})"
            ),
        }
    }
}

impl std::error::Error for RequestBodyStreamError {}
/// A pinned, boxed, `Send` stream of request-body chunks.
///
/// Each item is either a chunk of bytes or a [`RequestBodyStreamError`].
pub type RequestBodyStream = Pin<
    Box<dyn Stream<Item = Result<Vec<u8>, RequestBodyStreamError>> + Send + 'static>,
>;
/// HTTP protocol version named on the request line.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum HttpVersion {
    /// `HTTP/1.0`
    Http10,
    /// `HTTP/1.1`
    Http11,
    /// `HTTP/2` (also accepted as `HTTP/2.0`)
    Http2,
}

impl HttpVersion {
    /// Parses a protocol token such as `"HTTP/1.1"`.
    ///
    /// Matching is exact and case-sensitive. Unknown versions (including `HTTP/3`)
    /// and tokens without the `HTTP/` prefix yield `None`.
    #[must_use]
    pub fn parse(s: &str) -> Option<Self> {
        let number = s.strip_prefix("HTTP/")?;
        let version = match number {
            "1.0" => Self::Http10,
            "1.1" => Self::Http11,
            "2" | "2.0" => Self::Http2,
            _ => return None,
        };
        Some(version)
    }
}
/// Transport-level facts about the connection a request arrived on.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct ConnectionInfo {
    /// `true` when the connection is TLS-encrypted.
    pub is_tls: bool,
}

impl ConnectionInfo {
    /// A plaintext (non-TLS) connection.
    // The allow keeps the preset quiet if nothing in this crate references it;
    // presumably it exists for downstream users — confirm.
    #[allow(dead_code)]
    pub const HTTP: ConnectionInfo = ConnectionInfo { is_tls: false };

    /// A TLS-encrypted connection.
    #[allow(dead_code)]
    pub const HTTPS: ConnectionInfo = ConnectionInfo { is_tls: true };
}
/// Case-insensitive HTTP header map.
///
/// Names are stored ASCII-lowercased, so every lookup ignores ASCII case. Non-ASCII
/// characters in names are kept as-is. Each name maps to a single value: inserting
/// a name that is already present replaces the previous value.
#[derive(Debug, Default)]
pub struct Headers {
    /// Lowercased header name -> raw header value bytes.
    inner: HashMap<String, Vec<u8>>,
}

impl Headers {
    /// Creates an empty header map.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Normalizes `name` to the stored (ASCII-lowercase) form.
    ///
    /// Borrows `name` unchanged when it has no ASCII uppercase letters, so lookups
    /// with an already-lowercase name do not allocate.
    fn lookup_key(name: &str) -> std::borrow::Cow<'_, str> {
        if name.bytes().any(|b| b.is_ascii_uppercase()) {
            std::borrow::Cow::Owned(name.to_ascii_lowercase())
        } else {
            std::borrow::Cow::Borrowed(name)
        }
    }

    /// Returns the value stored for `name` (ASCII case-insensitive), if any.
    #[must_use]
    pub fn get(&self, name: &str) -> Option<&[u8]> {
        let key = Self::lookup_key(name);
        self.inner.get(&*key).map(Vec::as_slice)
    }

    /// Returns `true` if a value is stored for `name` (ASCII case-insensitive).
    #[must_use]
    pub fn contains(&self, name: &str) -> bool {
        let key = Self::lookup_key(name);
        self.inner.contains_key(&*key)
    }

    /// Stores `value` under `name`, replacing any existing value for that name.
    pub fn insert(&mut self, name: impl Into<String>, value: impl Into<Vec<u8>>) {
        // Lowercase the owned name in place instead of allocating a second String.
        let mut name: String = name.into();
        name.make_ascii_lowercase();
        self.inner.insert(name, value.into());
    }

    /// Stores a copy of `value` under `name`, replacing any existing value.
    pub fn insert_from_slice(&mut self, name: &str, value: &[u8]) {
        self.inner.insert(name.to_ascii_lowercase(), value.to_vec());
    }

    /// Removes and returns the value stored for `name` (ASCII case-insensitive).
    pub fn remove(&mut self, name: &str) -> Option<Vec<u8>> {
        let key = Self::lookup_key(name);
        self.inner.remove(&*key)
    }

    /// Iterates over `(lowercased name, value)` pairs in unspecified order.
    pub fn iter(&self) -> impl Iterator<Item = (&str, &[u8])> {
        self.inner
            .iter()
            .map(|(name, value)| (name.as_str(), value.as_slice()))
    }

    /// Number of distinct header names stored.
    #[must_use]
    pub fn len(&self) -> usize {
        self.inner.len()
    }

    /// Returns `true` if no headers are stored.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.inner.is_empty()
    }
}
/// A request body: absent, fully buffered, or delivered as a stream of chunks.
pub enum Body {
    /// No body.
    Empty,
    /// A fully buffered body.
    Bytes(Vec<u8>),
    /// A body delivered incrementally.
    Stream {
        /// The chunk stream. Wrapped in a `Mutex`, presumably so that `Body` is `Sync`
        /// even though the stream itself is only `Send` — confirm.
        stream: Mutex<RequestBodyStream>,
        /// Total body length in bytes, when known up front.
        content_length: Option<usize>,
    },
}

impl fmt::Debug for Body {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::Empty => f.debug_tuple("Empty").finish(),
            Self::Bytes(bytes) => {
                let mut tuple = f.debug_tuple("Bytes");
                tuple.field(bytes);
                tuple.finish()
            }
            Self::Stream { content_length, .. } => {
                // The stream is opaque; only its declared length is shown.
                let mut fields = f.debug_struct("Stream");
                fields.field("content_length", content_length);
                fields.finish()
            }
        }
    }
}
impl Body {
    /// Wraps `stream` as a streaming body of unknown total length.
    #[must_use]
    pub fn streaming<S>(stream: S) -> Self
    where
        S: Stream<Item = Result<Vec<u8>, RequestBodyStreamError>> + Send + 'static,
    {
        let pinned: RequestBodyStream = Box::pin(stream);
        Self::Stream {
            content_length: None,
            stream: Mutex::new(pinned),
        }
    }

    /// Wraps `stream` as a streaming body declared to be `content_length` bytes long.
    ///
    /// The length is recorded as given; it is not checked against the stream here.
    #[must_use]
    pub fn streaming_with_size<S>(stream: S, content_length: usize) -> Self
    where
        S: Stream<Item = Result<Vec<u8>, RequestBodyStreamError>> + Send + 'static,
    {
        let pinned: RequestBodyStream = Box::pin(stream);
        Self::Stream {
            content_length: Some(content_length),
            stream: Mutex::new(pinned),
        }
    }

    /// Converts the body into its buffered bytes.
    ///
    /// `Empty` yields an empty vector. NOTE: a `Stream` body also yields an empty
    /// vector, and its stream is dropped unread. Use [`Body::into_stream`] to consume
    /// a streaming body.
    #[must_use]
    pub fn into_bytes(self) -> Vec<u8> {
        match self {
            Self::Bytes(bytes) => bytes,
            Self::Empty | Self::Stream { .. } => Vec::new(),
        }
    }

    /// Reports whether the body is known to be empty.
    ///
    /// A stream counts as empty only when its declared length is `Some(0)`. A stream
    /// of unknown length is reported as non-empty.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        match self {
            Self::Empty => true,
            Self::Bytes(bytes) => bytes.is_empty(),
            Self::Stream { content_length, .. } => *content_length == Some(0),
        }
    }

    /// Splits a streaming body into its stream and declared length.
    ///
    /// Returns `None` for `Empty` and `Bytes` bodies, which are dropped. A poisoned
    /// mutex is recovered rather than reported.
    pub fn into_stream(self) -> Option<(RequestBodyStream, Option<usize>)> {
        match self {
            Self::Stream {
                stream,
                content_length,
            } => {
                let inner = match stream.into_inner() {
                    Ok(inner) => inner,
                    Err(poisoned) => poisoned.into_inner(),
                };
                Some((inner, content_length))
            }
            Self::Empty | Self::Bytes(_) => None,
        }
    }
}
/// Storage behind [`BackgroundTasks`]: a mutex-guarded list of boxed, pinned, `Send`
/// futures that each resolve to `()`.
pub type BackgroundTasksInner = Mutex<Vec<Pin<Box<dyn Future<Output = ()> + Send + 'static>>>>;

/// A queue of deferred work.
///
/// Work is only collected by [`add`](Self::add) and [`add_async`](Self::add_async).
/// Nothing runs until [`execute_all`](Self::execute_all) is awaited.
pub struct BackgroundTasks {
    tasks: BackgroundTasksInner,
}

impl fmt::Debug for BackgroundTasks {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Queued futures are opaque, so only the type name is printed.
        let mut out = f.debug_struct("BackgroundTasks");
        out.finish_non_exhaustive()
    }
}

impl BackgroundTasks {
    /// Creates an empty queue.
    #[must_use]
    pub fn new() -> Self {
        let tasks = Mutex::new(Vec::new());
        Self { tasks }
    }

    /// Queues a synchronous closure. It runs when [`execute_all`](Self::execute_all)
    /// reaches it.
    pub fn add<F>(&self, f: F)
    where
        F: FnOnce() + Send + 'static,
    {
        // Lift the closure into a future so both kinds of work share one queue.
        let as_future = async move { f() };
        self.add_async(as_future);
    }

    /// Queues a future. It is not polled until [`execute_all`](Self::execute_all)
    /// runs.
    ///
    /// A poisoned lock is recovered, so the task is still queued.
    pub fn add_async<Fut>(&self, fut: Fut)
    where
        Fut: Future<Output = ()> + Send + 'static,
    {
        let mut queue = self
            .tasks
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner());
        queue.push(Box::pin(fut));
    }

    /// Consumes the queue and awaits every task sequentially, in insertion order.
    ///
    /// A poisoned lock is recovered rather than reported.
    pub async fn execute_all(self) {
        let pending = match self.tasks.into_inner() {
            Ok(list) => list,
            Err(poisoned) => poisoned.into_inner(),
        };
        for task in pending {
            task.await;
        }
    }
}

impl Default for BackgroundTasks {
    /// Same as [`BackgroundTasks::new`].
    fn default() -> Self {
        BackgroundTasks::new()
    }
}
#[derive(Debug)]
pub struct Request {
method: Method,
version: HttpVersion,
path: String,
query: Option<String>,
headers: Headers,
body: Body,
#[allow(dead_code)] extensions: HashMap<std::any::TypeId, Box<dyn std::any::Any + Send + Sync>>,
}
impl Request {
    /// Creates a request for `method` and `path`.
    ///
    /// The request uses HTTP/1.1 and starts with no query, no headers, an empty body,
    /// and no extensions.
    #[must_use]
    pub fn new(method: Method, path: impl Into<String>) -> Self {
        Self {
            method,
            version: HttpVersion::Http11,
            path: path.into(),
            query: None,
            headers: Headers::new(),
            body: Body::Empty,
            extensions: HashMap::new(),
        }
    }

    /// Like [`Request::new`], but with an explicit protocol `version`.
    #[must_use]
    pub fn with_version(method: Method, path: impl Into<String>, version: HttpVersion) -> Self {
        let mut req = Self::new(method, path);
        req.version = version;
        req
    }

    /// Returns the request method.
    #[must_use]
    pub fn method(&self) -> Method {
        self.method
    }

    /// Returns the protocol version.
    #[must_use]
    pub fn version(&self) -> HttpVersion {
        self.version
    }

    /// Overrides the protocol version.
    pub fn set_version(&mut self, version: HttpVersion) {
        self.version = version;
    }

    /// Returns the request path.
    #[must_use]
    pub fn path(&self) -> &str {
        &self.path
    }

    /// Returns the query string, if one has been set.
    #[must_use]
    pub fn query(&self) -> Option<&str> {
        self.query.as_deref()
    }

    /// Returns the headers.
    #[must_use]
    pub fn headers(&self) -> &Headers {
        &self.headers
    }

    /// Returns the headers for modification.
    pub fn headers_mut(&mut self) -> &mut Headers {
        &mut self.headers
    }

    /// Returns a reference to the body.
    #[must_use]
    pub fn body(&self) -> &Body {
        &self.body
    }

    /// Moves the body out, leaving `Body::Empty` in its place.
    pub fn take_body(&mut self) -> Body {
        std::mem::replace(&mut self.body, Body::Empty)
    }

    /// Replaces the body, dropping the previous one.
    pub fn set_body(&mut self, body: Body) {
        self.body = body;
    }

    /// Replaces the query string (`None` clears it).
    pub fn set_query(&mut self, query: Option<String>) {
        self.query = query;
    }

    /// Stores `value` keyed by its type `T`.
    ///
    /// Any previously stored value of the same type is replaced and dropped.
    pub fn insert_extension<T: Any + Send + Sync>(&mut self, value: T) {
        self.extensions.insert(TypeId::of::<T>(), Box::new(value));
    }

    /// Returns the stored extension of type `T`, if any.
    #[must_use]
    pub fn get_extension<T: Any + Send + Sync>(&self) -> Option<&T> {
        self.extensions
            .get(&TypeId::of::<T>())
            .and_then(|boxed| boxed.downcast_ref::<T>())
    }

    /// Returns a mutable reference to the stored extension of type `T`, if any.
    pub fn get_extension_mut<T: Any + Send + Sync>(&mut self) -> Option<&mut T> {
        self.extensions
            .get_mut(&TypeId::of::<T>())
            .and_then(|boxed| boxed.downcast_mut::<T>())
    }

    /// Removes and returns the stored extension of type `T`, if any.
    pub fn take_extension<T: Any + Send + Sync>(&mut self) -> Option<T> {
        self.extensions
            .remove(&TypeId::of::<T>())
            .and_then(|boxed| boxed.downcast::<T>().ok())
            .map(|boxed| *boxed)
    }

    /// Returns this request's [`BackgroundTasks`], creating an empty queue on first
    /// access.
    ///
    /// Takes `&mut self` because it may insert the queue as an extension.
    ///
    /// # Panics
    ///
    /// Never in practice. The entry is keyed by `TypeId::of::<BackgroundTasks>()`, and
    /// every insert path stores a value of the key's type, so the downcast cannot fail.
    pub fn background_tasks(&mut self) -> &BackgroundTasks {
        // A single entry lookup replaces contains_key + insert + get.
        self.extensions
            .entry(TypeId::of::<BackgroundTasks>())
            .or_insert_with(|| Box::new(BackgroundTasks::new()))
            .downcast_ref::<BackgroundTasks>()
            .expect("BackgroundTasks extension should exist")
    }
}