1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99
use std::ops::Deref;
use crate::{error::GetDataError, FromRequest, Request, RequestBody, Result};
/// An extractor that can extract data from the request extension.
///
/// The value is looked up by its type `T` in the request's extensions and is
/// handed to the handler as a shared reference (`Data<&T>`). Values are
/// typically registered with [`EndpointExt::data`](crate::EndpointExt) or the
/// [`AddData`](crate::middleware::AddData) middleware.
///
/// # Errors
///
/// - [`GetDataError`] if no value of type `T` is present in the request
///   extensions; the error carries the type name of `T` and results in a
///   `500 Internal Server Error` response.
///
/// # Example
///
/// ```
/// use poem::{
/// get, handler, http::StatusCode, web::Data, Endpoint, EndpointExt, Request, Route,
/// };
///
/// #[handler]
/// async fn index(data: Data<&i32>) {
/// assert_eq!(*data.0, 10);
/// }
///
/// # tokio::runtime::Runtime::new().unwrap().block_on(async {
/// let app = Route::new().at("/", get(index)).data(10i32);
/// let resp = app.get_response(Request::default()).await;
/// assert_eq!(resp.status(), StatusCode::OK);
/// # });
/// ```
pub struct Data<T>(pub T);
impl<T> Deref for Data<T> {
type Target = T;
fn deref(&self) -> &Self::Target {
&self.0
}
}
impl<'a, T: Send + Sync + 'static> FromRequest<'a> for Data<&'a T> {
    /// Looks up a value of type `T` in the request extensions and wraps a
    /// reference to it. The request body is not consumed.
    ///
    /// Returns a [`GetDataError`] (converted into the crate's error type)
    /// naming `T` when no such value has been registered.
    async fn from_request(req: &'a Request, _body: &mut RequestBody) -> Result<Self> {
        let Some(found) = req.extensions().get::<T>() else {
            return Err(GetDataError(std::any::type_name::<T>()).into());
        };
        Ok(Data(found))
    }
}
#[cfg(test)]
mod tests {
    use http::StatusCode;

    use super::*;
    use crate::{handler, middleware::AddData, test::TestClient, EndpointExt};

    /// A value registered through `AddData` reaches the handler by reference.
    #[tokio::test]
    async fn test_data_extractor() {
        #[handler(internal)]
        async fn read_number(number: Data<&i32>) {
            assert_eq!(*number.0, 100);
        }

        let client = TestClient::new(read_number.with(AddData::new(100i32)));
        client.get("/").send().await.assert_status_is_ok();
    }

    /// Without registered data the extractor fails before the handler runs.
    #[tokio::test]
    async fn test_data_extractor_error() {
        #[handler(internal)]
        async fn read_missing(_number: Data<&i32>) {
            unreachable!()
        }

        let client = TestClient::new(read_missing);
        client
            .get("/")
            .send()
            .await
            .assert_status(StatusCode::INTERNAL_SERVER_ERROR);
    }

    /// Methods of the inner type are reachable through `Deref`.
    #[tokio::test]
    async fn test_data_extractor_deref() {
        #[handler(internal)]
        async fn shout(text: Data<&String>) {
            assert_eq!(text.to_uppercase(), "ABC");
        }

        let endpoint = shout.with(AddData::new(String::from("abc")));
        TestClient::new(endpoint)
            .get("/")
            .send()
            .await
            .assert_status_is_ok();
    }
}