use std::{marker::PhantomData, ops::Deref};
use crate::http::Payload;
use crate::web::error::{ErrorRenderer, StateExtractorError};
use crate::web::extract::FromRequest;
use crate::web::httprequest::HttpRequest;
use crate::web::service::AppState;
#[derive(Debug)]
pub struct State<T>(AppState, PhantomData<T>);
impl<T: 'static> State<T> {
pub fn get_ref(&self) -> &T {
self.0.get::<T>().expect("Unexpected state")
}
}
impl<T: 'static> Deref for State<T> {
type Target = T;
fn deref(&self) -> &T {
self.get_ref()
}
}
impl<T> Clone for State<T> {
fn clone(&self) -> State<T> {
State(self.0.clone(), PhantomData)
}
}
impl<T: 'static, E: ErrorRenderer> FromRequest<E> for State<T> {
type Error = StateExtractorError;
#[inline]
async fn from_request(req: &HttpRequest, _: &mut Payload) -> Result<Self, Self::Error> {
if req.0.app_state.contains::<T>() {
Ok(Self(req.0.app_state.clone(), PhantomData))
} else {
log::debug!(
"Failed to construct App-level State extractor. \
Request path: {:?}",
req.path()
);
Err(StateExtractorError::NotConfigured)
}
}
}
#[cfg(test)]
mod tests {
use crate::web::test::{TestRequest, init_service};
use crate::web::{self, App, HttpResponse};
use crate::{http::StatusCode, time::Millis, time::sleep};
#[crate::rt_test]
async fn test_state_extractor() {
let srv = init_service(
App::new().state(10usize).service(
web::resource("/")
.to(|_: web::types::State<usize>| async { HttpResponse::Ok() }),
),
)
.await;
let req = TestRequest::default().to_request();
let resp = srv.call(req).await.unwrap();
assert_eq!(resp.status(), StatusCode::OK);
let srv = init_service(
App::new().state(10u32).service(
web::resource("/")
.to(|_: web::types::State<usize>| async { HttpResponse::Ok() }),
),
)
.await;
let req = TestRequest::default().to_request();
let res = srv.call(req).await.unwrap();
assert_eq!(res.status(), StatusCode::INTERNAL_SERVER_ERROR);
}
#[crate::rt_test]
async fn test_state_drop() {
use std::sync::{Arc, atomic::AtomicUsize, atomic::Ordering};
struct TestData(Arc<AtomicUsize>);
impl TestData {
fn new(inner: Arc<AtomicUsize>) -> Self {
let _ = inner.fetch_add(1, Ordering::SeqCst);
Self(inner)
}
}
impl Clone for TestData {
fn clone(&self) -> Self {
let inner = self.0.clone();
let _ = inner.fetch_add(1, Ordering::SeqCst);
Self(inner)
}
}
impl Drop for TestData {
fn drop(&mut self) {
let _ = self.0.fetch_sub(1, Ordering::SeqCst);
}
}
let num = Arc::new(AtomicUsize::new(0));
let data = TestData::new(num.clone());
assert_eq!(num.load(Ordering::SeqCst), 1);
{
let srv = web::test::server(async move || {
let data = data.clone();
App::new().state(data).service(
web::resource("/").to(|_data: super::State<TestData>| async { "ok" }),
)
})
.await;
assert!(srv.get("/").send().await.unwrap().status().is_success());
srv.stop().await;
sleep(Millis(500)).await;
}
assert!(num.load(Ordering::SeqCst) <= 1);
}
#[crate::rt_test]
async fn test_route_state_extractor() {
let srv =
init_service(App::new().service(web::resource("/").state(10usize).route(
web::get().to(|data: web::types::State<usize>| async move {
let _ = data.clone();
HttpResponse::Ok()
}),
)))
.await;
let req = TestRequest::default().to_request();
let resp = srv.call(req).await.unwrap();
assert_eq!(resp.status(), StatusCode::OK);
let srv = init_service(App::new().service(web::resource("/").state(10u32).route(
web::get().to(|_: web::types::State<usize>| async { HttpResponse::Ok() }),
)))
.await;
let req = TestRequest::default().to_request();
let res = srv.call(req).await.unwrap();
assert_eq!(res.status(), StatusCode::INTERNAL_SERVER_ERROR);
}
#[crate::rt_test]
async fn test_override_state() {
let srv = init_service(App::new().state(1usize).service(
web::resource("/").state(10usize).route(web::get().to(
|data: web::types::State<usize>| async move {
assert_eq!(*data, 10);
let _ = data.clone();
HttpResponse::Ok()
},
)),
))
.await;
let req = TestRequest::default().to_request();
let resp = srv.call(req).await.unwrap();
assert_eq!(resp.status(), StatusCode::OK);
}
}