use crate::{HttpError, ServerFnError};
use axum_core::extract::FromRequest;
use axum_core::response::IntoResponse;
use dioxus_core::{CapturedError, ReactiveContext};
use http::StatusCode;
use http::{request::Parts, HeaderMap};
use parking_lot::RwLock;
use std::collections::HashSet;
use std::fmt::Debug;
use std::sync::Arc;
#[derive(Clone, Debug)]
pub struct FullstackContext {
request_headers: Arc<RwLock<http::request::Parts>>,
lock: Arc<RwLock<FullstackContextInner>>,
}
// Task-local slot holding the context of the request currently being served.
// It is installed by `FullstackContext::scope` and read first by `FullstackContext::current`.
tokio::task_local! {
    static FULLSTACK_CONTEXT: FullstackContext;
}
/// Mutable state kept behind [`FullstackContext`]'s lock.
///
/// Each observable value is stored next to the set of reactive contexts that read it, so a
/// writer can mark exactly those readers dirty when the value changes.
pub struct FullstackContextInner {
    /// Whether the initial chunk is still rendering or has been committed.
    current_status: StreamingStatus,
    /// Reactive contexts that read `current_status` and are marked dirty when it changes.
    current_status_subscribers: HashSet<ReactiveContext>,
    /// HTTP status (and optional message) recorded for the current route.
    route_http_status: HttpError,
    /// Reactive contexts that read `route_http_status` and are marked dirty when it changes.
    route_http_status_subscribers: HashSet<ReactiveContext>,
    /// Headers to attach to the response; `None` once they have been taken.
    response_headers: Option<HeaderMap>,
}
impl Debug for FullstackContextInner {
    /// Formats the observable state of the context.
    ///
    /// The two reactive subscriber sets are intentionally left out. `finish_non_exhaustive`
    /// makes that visible by emitting a trailing `..` instead of pretending the listed fields
    /// are the whole struct.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("FullstackContextInner")
            .field("current_status", &self.current_status)
            .field("response_headers", &self.response_headers)
            .field("route_http_status", &self.route_http_status)
            .finish_non_exhaustive()
    }
}
impl PartialEq for FullstackContext {
    /// Identity comparison: two contexts are equal only when they are clones of the same
    /// context, i.e. both the request head and the inner state live in the same allocations.
    fn eq(&self, other: &Self) -> bool {
        let same_request = Arc::ptr_eq(&self.request_headers, &other.request_headers);
        let same_state = Arc::ptr_eq(&self.lock, &other.lock);
        same_state && same_request
    }
}
impl FullstackContext {
    /// Creates a context for a request whose head is `parts`.
    ///
    /// The new context has:
    /// - streaming status [`StreamingStatus::RenderingInitialChunk`];
    /// - route status `200 OK` with no message;
    /// - an empty response header map for [`FullstackContext::add_response_header`] to write into.
    pub fn new(parts: Parts) -> Self {
        Self {
            request_headers: RwLock::new(parts).into(),
            lock: RwLock::new(FullstackContextInner {
                current_status: StreamingStatus::RenderingInitialChunk,
                current_status_subscribers: Default::default(),
                route_http_status: HttpError {
                    status: http::StatusCode::OK,
                    message: None,
                },
                route_http_status_subscribers: Default::default(),
                response_headers: Some(HeaderMap::new()),
            })
            .into(),
        }
    }

    /// Marks the initial chunk as committed and notifies every reactive context that read
    /// the status through [`FullstackContext::streaming_state`] since the previous commit.
    ///
    /// The subscriber set is drained, so each subscriber is notified at most once per call.
    pub fn commit_initial_chunk(&mut self) {
        // Update the status and drain the subscribers inside a short scope, so the write
        // guard is released *before* any subscriber is notified. `parking_lot::RwLock` is
        // not reentrant. If `mark_dirty` ever re-entered this context (for example by
        // calling `streaming_state`), holding the guard here would deadlock.
        //
        // Clippy's `mutable_key_type` lint flags sets of `ReactiveContext`. The set is only
        // moved and iterated here, never looked up by key.
        #[allow(clippy::mutable_key_type)]
        let subscribers = {
            let mut lock = self.lock.write();
            lock.current_status = StreamingStatus::InitialChunkCommitted;
            std::mem::take(&mut lock.current_status_subscribers)
        };
        for subscriber in subscribers {
            subscriber.mark_dirty();
        }
    }

    /// Returns the current [`StreamingStatus`].
    ///
    /// When called inside a [`ReactiveContext`], that context is subscribed and will be
    /// marked dirty by the next [`FullstackContext::commit_initial_chunk`].
    pub fn streaming_state(&self) -> StreamingStatus {
        // Resolve the reactive context before locking so no runtime code runs under the guard.
        let reader = ReactiveContext::current();
        let mut lock = self.lock.write();
        if let Some(ctx) = reader {
            lock.current_status_subscribers.insert(ctx);
        }
        lock.current_status
    }

    /// Returns a write guard over the request head.
    ///
    /// While the guard is alive, [`FullstackContext::extension`] and
    /// [`FullstackContext::extract`] on this context (or any clone of it) block, because
    /// they lock the same `RwLock`. On the same thread that is a deadlock, since
    /// `parking_lot::RwLock` is not reentrant. Drop the guard promptly and never hold it
    /// across an `.await`.
    pub fn parts_mut(&self) -> parking_lot::RwLockWriteGuard<'_, http::request::Parts> {
        self.request_headers.write()
    }

    /// Drives `fut` to completion with `self` installed as the task-local context.
    ///
    /// Inside `fut`, [`FullstackContext::current`] returns a clone of `self`.
    pub async fn scope<F, R>(self, fut: F) -> R
    where
        F: std::future::Future<Output = R>,
    {
        FULLSTACK_CONTEXT.scope(self, fut).await
    }

    /// Returns a clone of the request extension of type `T`, or `None` if it is absent.
    pub fn extension<T: Clone + Send + Sync + 'static>(&self) -> Option<T> {
        let lock = self.request_headers.read();
        lock.extensions.get::<T>().cloned()
    }

    /// Runs the axum extractor `T` against the current request.
    ///
    /// The request head is cloned from the context and paired with a `Default` body.
    /// NOTE(review): that is presumably an empty body, so body-consuming extractors will
    /// see no data here. Confirm before relying on them.
    ///
    /// If no context is available (see [`FullstackContext::current`]), a placeholder
    /// `GET /` request carrying an `X-Dummy-Header: true` header is used instead.
    ///
    /// # Errors
    ///
    /// If the extractor rejects the request, the rejection is rendered with
    /// `IntoResponse` and converted with `ServerFnError::from_axum_response`.
    pub async fn extract<T: FromRequest<Self, M>, M>() -> Result<T, ServerFnError> {
        let this = Self::current().unwrap_or_else(|| {
            FullstackContext::new(
                axum_core::extract::Request::builder()
                    .method("GET")
                    .uri("/")
                    .header("X-Dummy-Header", "true")
                    .body(())
                    // Every input is a static, valid literal, so the builder cannot fail.
                    .expect("placeholder request built from static parts is valid")
                    .into_parts()
                    .0,
            )
        });
        // Clone the head in its own statement so the read guard is dropped before the
        // `.await` below and the extractor is free to use the context itself.
        let parts = this.request_headers.read().clone();
        let request = axum_core::extract::Request::from_parts(parts, Default::default());
        match T::from_request(request, &this).await {
            Ok(res) => Ok(res),
            Err(err) => {
                let resp = err.into_response();
                Err(ServerFnError::from_axum_response(resp).await)
            }
        }
    }

    /// Returns the context for the currently running code, if any.
    ///
    /// Lookup order:
    /// 1. The task-local installed by [`FullstackContext::scope`].
    /// 2. If a dioxus runtime with a current scope exists, a `FullstackContext` obtained
    ///    through `consume_context` for that scope.
    ///
    /// Returns `None` when neither yields a context.
    pub fn current() -> Option<Self> {
        if let Ok(context) = FULLSTACK_CONTEXT.try_get() {
            return Some(context);
        }
        if let Some(rt) = dioxus_core::Runtime::try_current() {
            let id = rt.try_current_scope_id()?;
            if let Some(ctx) = rt.consume_context::<FullstackContext>(id) {
                return Some(ctx);
            }
        }
        None
    }

    /// Returns a clone of the HTTP status recorded for the current route.
    ///
    /// When called inside a [`ReactiveContext`], that context is subscribed and will be
    /// marked dirty by the next [`FullstackContext::set_current_http_status`].
    pub fn current_http_status(&self) -> HttpError {
        // Resolve the reactive context before locking so no runtime code runs under the guard.
        let reader = ReactiveContext::current();
        let mut lock = self.lock.write();
        if let Some(ctx) = reader {
            lock.route_http_status_subscribers.insert(ctx);
        }
        lock.route_http_status.clone()
    }

    /// Replaces the route's HTTP status and notifies every reactive context that read it
    /// through [`FullstackContext::current_http_status`].
    ///
    /// The subscriber set is drained, so each subscriber is notified at most once per call.
    pub fn set_current_http_status(&mut self, status: HttpError) {
        // Release the (non-reentrant) write guard before notifying subscribers, so a
        // subscriber that reads the status back during `mark_dirty` cannot deadlock.
        #[allow(clippy::mutable_key_type)]
        let subscribers = {
            let mut lock = self.lock.write();
            lock.route_http_status = status;
            std::mem::take(&mut lock.route_http_status_subscribers)
        };
        for subscriber in subscribers {
            subscriber.mark_dirty();
        }
    }

    /// Sets a response header.
    ///
    /// Uses `HeaderMap::insert`, so any existing value(s) for `key` are *replaced*, not
    /// appended to. Silently does nothing once
    /// [`FullstackContext::take_response_headers`] has been called.
    pub fn add_response_header(
        &self,
        key: impl Into<http::header::HeaderName>,
        value: impl Into<http::header::HeaderValue>,
    ) {
        let mut lock = self.lock.write();
        if let Some(headers) = lock.response_headers.as_mut() {
            headers.insert(key.into(), value.into());
        }
    }

    /// Takes the accumulated response headers, leaving `None` behind.
    ///
    /// Later calls return `None`, and later
    /// [`FullstackContext::add_response_header`] calls are ignored.
    pub fn take_response_headers(&self) -> Option<HeaderMap> {
        let mut lock = self.lock.write();
        lock.response_headers.take()
    }

    /// Records `status` (with an optional `message`) as the route status of the current
    /// context. Does nothing when there is no current context.
    pub fn commit_http_status(status: StatusCode, message: Option<String>) {
        if let Some(mut ctx) = Self::current() {
            ctx.set_current_http_status(HttpError { status, message });
        }
    }

    /// Converts `error` into an [`HttpError`] and returns it.
    ///
    /// The status comes from [`status_code_from_error`] and the message is the error's
    /// `Display` text. The result is also recorded on the current context, if there is one.
    pub fn commit_error_status(error: impl Into<CapturedError>) -> HttpError {
        let error = error.into();
        let status = status_code_from_error(&error);
        let http_error = HttpError {
            status,
            message: Some(error.to_string()),
        };
        if let Some(mut ctx) = Self::current() {
            ctx.set_current_http_status(http_error.clone());
        }
        http_error
    }
}
/// Progress of the streaming server-side render.
///
/// `Eq` and `Hash` are derived alongside `PartialEq` so the status can be used as a
/// set or map key.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum StreamingStatus {
    /// The initial chunk has not been committed yet; every new `FullstackContext` starts here.
    RenderingInitialChunk,
    /// The initial chunk has been committed. This is also what `current_status()` reports
    /// when no context is available.
    InitialChunkCommitted,
}
/// Finalizes the route, then marks the initial chunk as committed on the current
/// [`FullstackContext`].
///
/// Route finalization always runs; the commit is skipped when there is no context.
pub fn commit_initial_chunk() {
    crate::history::finalize_route();
    let Some(mut ctx) = FullstackContext::current() else {
        return;
    };
    ctx.commit_initial_chunk();
}
/// Deprecated free-function alias that forwards to [`FullstackContext::extract`].
#[deprecated(note = "Use FullstackContext::extract instead", since = "0.7.0")]
pub fn extract<T, M>() -> impl std::future::Future<Output = Result<T, ServerFnError>>
where
    T: FromRequest<FullstackContext, M>,
{
    FullstackContext::extract::<T, M>()
}
/// Reports the streaming status of the current [`FullstackContext`].
///
/// When a context exists, its status is read through `streaming_state`, which also
/// subscribes the active reactive context (if any) to later changes. When there is no
/// context, this reports [`StreamingStatus::InitialChunkCommitted`].
pub fn current_status() -> StreamingStatus {
    FullstackContext::current().map_or(StreamingStatus::InitialChunkCommitted, |ctx| {
        ctx.streaming_state()
    })
}
pub fn status_code_from_error(error: &CapturedError) -> StatusCode {
if let Some(err) = error.downcast_ref::<ServerFnError>() {
match err {
ServerFnError::ServerError { code, .. } => {
return StatusCode::from_u16(*code).unwrap_or(StatusCode::INTERNAL_SERVER_ERROR)
}
_ => return StatusCode::INTERNAL_SERVER_ERROR,
}
}
if let Some(err) = error.downcast_ref::<StatusCode>() {
return *err;
}
if let Some(err) = error.downcast_ref::<HttpError>() {
return err.status;
}
StatusCode::INTERNAL_SERVER_ERROR
}