poem/endpoint/
endpoint.rs

1use std::{future::Future, marker::PhantomData, sync::Arc};
2
3use futures_util::{FutureExt, future::BoxFuture};
4
5use super::{
6    After, AndThen, Around, Before, CatchAllError, CatchError, InspectAllError, InspectError, Map,
7    MapToResponse, ToResponse,
8};
9use crate::{
10    Error, IntoResponse, Middleware, Request, Response, Result,
11    error::IntoResult,
12    middleware::{AddData, AddDataEndpoint},
13};
14
15/// An HTTP request handler.
16pub trait Endpoint: Send + Sync {
17    /// Represents the response of the endpoint.
18    type Output: IntoResponse;
19
20    /// Get the response to the request.
21    fn call(&self, req: Request) -> impl Future<Output = Result<Self::Output>> + Send;
22
23    /// Get the response to the request and return a [`Response`].
24    ///
25    /// Unlike [`Endpoint::call`], when an error occurs, it will also convert
26    /// the error into a response object.
27    ///
28    /// # Example
29    ///
30    /// ```
31    /// use poem::{
32    ///     Endpoint, Request, Result, error::NotFoundError, handler, http::StatusCode,
33    ///     test::TestClient,
34    /// };
35    ///
36    /// #[handler]
37    /// fn index() -> Result<()> {
38    ///     Err(NotFoundError.into())
39    /// }
40    ///
41    /// # tokio::runtime::Runtime::new().unwrap().block_on(async {
42    /// TestClient::new(index)
43    ///     .get("/")
44    ///     .send()
45    ///     .await
46    ///     .assert_status(StatusCode::NOT_FOUND);
47    /// # });
48    /// ```
49    fn get_response(&self, req: Request) -> impl Future<Output = Response> + Send {
50        async move {
51            self.call(req)
52                .await
53                .map(IntoResponse::into_response)
54                .unwrap_or_else(|err| err.into_response())
55        }
56    }
57}
58
59struct SyncFnEndpoint<T, F> {
60    _mark: PhantomData<T>,
61    f: F,
62}
63
64impl<F, T, R> Endpoint for SyncFnEndpoint<T, F>
65where
66    F: Fn(Request) -> R + Send + Sync,
67    T: IntoResponse + Sync,
68    R: IntoResult<T>,
69{
70    type Output = T;
71
72    async fn call(&self, req: Request) -> Result<Self::Output> {
73        (self.f)(req).into_result()
74    }
75}
76
77struct AsyncFnEndpoint<T, F> {
78    _mark: PhantomData<T>,
79    f: F,
80}
81
82impl<F, Fut, T, R> Endpoint for AsyncFnEndpoint<T, F>
83where
84    F: Fn(Request) -> Fut + Sync + Send,
85    Fut: Future<Output = R> + Send,
86    T: IntoResponse + Sync,
87    R: IntoResult<T>,
88{
89    type Output = T;
90
91    async fn call(&self, req: Request) -> Result<Self::Output> {
92        (self.f)(req).await.into_result()
93    }
94}
95
96/// The enum `EitherEndpoint` with variants `Left` and `Right` is a general
97/// purpose sum type with two cases.
98pub enum EitherEndpoint<A, B> {
99    /// A endpoint of type `A`
100    A(A),
101    /// A endpoint of type `B`
102    B(B),
103}
104
105impl<A, B> Endpoint for EitherEndpoint<A, B>
106where
107    A: Endpoint,
108    B: Endpoint,
109{
110    type Output = Response;
111
112    async fn call(&self, req: Request) -> Result<Self::Output> {
113        match self {
114            EitherEndpoint::A(a) => a.call(req).await.map(IntoResponse::into_response),
115            EitherEndpoint::B(b) => b.call(req).await.map(IntoResponse::into_response),
116        }
117    }
118}
119
120/// Create an endpoint with a function.
121///
122/// The output can be any type that implements [`IntoResult`].
123///
124/// # Example
125///
126/// ```
127/// use poem::{Endpoint, Request, endpoint::make_sync, http::Method, test::TestClient};
128///
129/// let ep = make_sync(|req| req.method().to_string());
130/// let cli = TestClient::new(ep);
131///
132/// # tokio::runtime::Runtime::new().unwrap().block_on(async {
133/// let resp = cli.get("/").send().await;
134/// resp.assert_status_is_ok();
135/// resp.assert_text("GET").await;
136/// # });
137/// ```
138pub fn make_sync<F, T, R>(f: F) -> impl Endpoint<Output = T>
139where
140    F: Fn(Request) -> R + Send + Sync,
141    T: IntoResponse + Sync,
142    R: IntoResult<T>,
143{
144    SyncFnEndpoint {
145        _mark: PhantomData,
146        f,
147    }
148}
149
150/// Create an endpoint with a asyncness function.
151///
152/// The output can be any type that implements [`IntoResult`].
153///
154/// # Example
155///
156/// ```
157/// use poem::{Endpoint, Request, endpoint::make, http::Method, test::TestClient};
158///
159/// let ep = make(|req| async move { req.method().to_string() });
160/// let app = TestClient::new(ep);
161///
162/// # tokio::runtime::Runtime::new().unwrap().block_on(async {
163/// let resp = app.get("/").send().await;
164/// resp.assert_status_is_ok();
165/// resp.assert_text("GET").await;
166/// # });
167/// ```
168pub fn make<F, Fut, T, R>(f: F) -> impl Endpoint<Output = T>
169where
170    F: Fn(Request) -> Fut + Send + Sync,
171    Fut: Future<Output = R> + Send,
172    T: IntoResponse + Sync,
173    R: IntoResult<T>,
174{
175    AsyncFnEndpoint {
176        _mark: PhantomData,
177        f,
178    }
179}
180
181impl<T: Endpoint + ?Sized> Endpoint for &T {
182    type Output = T::Output;
183
184    async fn call(&self, req: Request) -> Result<Self::Output> {
185        T::call(self, req).await
186    }
187}
188
189impl<T: Endpoint + ?Sized> Endpoint for Box<T> {
190    type Output = T::Output;
191
192    async fn call(&self, req: Request) -> Result<Self::Output> {
193        self.as_ref().call(req).await
194    }
195}
196
197impl<T: Endpoint + ?Sized> Endpoint for Arc<T> {
198    type Output = T::Output;
199
200    async fn call(&self, req: Request) -> Result<Self::Output> {
201        self.as_ref().call(req).await
202    }
203}
204
205/// A `endpoint` that can be dynamically dispatched.
206pub trait DynEndpoint: Send + Sync {
207    /// Represents the response of the endpoint.
208    type Output: IntoResponse;
209
210    /// Get the response to the request.
211    fn call(&self, req: Request) -> BoxFuture<'_, Result<Self::Output>>;
212}
213
214/// A [`Endpoint`] wrapper used to implement [`DynEndpoint`].
215pub struct ToDynEndpoint<E>(pub E);
216
217impl<E> DynEndpoint for ToDynEndpoint<E>
218where
219    E: Endpoint,
220{
221    type Output = E::Output;
222
223    #[inline]
224    fn call(&self, req: Request) -> BoxFuture<'_, Result<Self::Output>> {
225        self.0.call(req).boxed()
226    }
227}
228
229impl<T> Endpoint for dyn DynEndpoint<Output = T> + '_
230where
231    T: IntoResponse,
232{
233    type Output = T;
234
235    #[inline]
236    async fn call(&self, req: Request) -> Result<Self::Output> {
237        DynEndpoint::call(self, req).await
238    }
239}
240
241/// An owned dynamically typed `Endpoint` for use in cases where you can’t
242/// statically type your result or need to add some indirection.
243pub type BoxEndpoint<'a, T = Response> = Box<dyn DynEndpoint<Output = T> + 'a>;
244
245/// Extension trait for [`Endpoint`].
246pub trait EndpointExt: IntoEndpoint {
247    /// Wrap the endpoint in a Box.
248    fn boxed<'a>(self) -> BoxEndpoint<'a, <Self::Endpoint as Endpoint>::Output>
249    where
250        Self: Sized + 'a,
251    {
252        Box::new(ToDynEndpoint(self.into_endpoint()))
253    }
254
255    /// Use middleware to transform this endpoint.
256    ///
257    /// # Example
258    ///
259    /// ```
260    /// use poem::{
261    ///     Endpoint, EndpointExt, Request, Route, get, handler, http::StatusCode, middleware::AddData,
262    ///     test::TestClient, web::Data,
263    /// };
264    ///
265    /// #[handler]
266    /// async fn index(Data(data): Data<&i32>) -> String {
267    ///     format!("{}", data)
268    /// }
269    ///
270    /// let app = Route::new().at("/", get(index)).with(AddData::new(100i32));
271    /// let cli = TestClient::new(app);
272    ///
273    /// # tokio::runtime::Runtime::new().unwrap().block_on(async {
274    /// let resp = cli.get("/").send().await;
275    /// resp.assert_status_is_ok();
276    /// resp.assert_text("100").await;
277    /// # });
278    /// ```
279    fn with<T>(self, middleware: T) -> T::Output
280    where
281        T: Middleware<Self::Endpoint>,
282        Self: Sized,
283    {
284        middleware.transform(self.into_endpoint())
285    }
286
287    /// if `enable` is `true` then use middleware to transform this endpoint.
288    ///
289    /// # Example
290    ///
291    /// ```
292    /// use poem::{
293    ///     Endpoint, EndpointExt, Request, Route, get, handler,
294    ///     http::{StatusCode, Uri},
295    ///     middleware::AddData,
296    ///     test::TestClient,
297    ///     web::Data,
298    /// };
299    ///
300    /// #[handler]
301    /// async fn index(data: Option<Data<&i32>>) -> String {
302    ///     match data {
303    ///         Some(data) => data.0.to_string(),
304    ///         None => "none".to_string(),
305    ///     }
306    /// }
307    ///
308    /// let app = Route::new()
309    ///     .at("/a", get(index).with_if(true, AddData::new(100i32)))
310    ///     .at("/b", get(index).with_if(false, AddData::new(100i32)));
311    /// let cli = TestClient::new(app);
312    ///
313    /// # tokio::runtime::Runtime::new().unwrap().block_on(async {
314    /// let resp = cli.get("/a").send().await;
315    /// resp.assert_status_is_ok();
316    /// resp.assert_text("100").await;
317    ///
318    /// let resp = cli.get("/b").send().await;
319    /// resp.assert_status_is_ok();
320    /// resp.assert_text("none").await;
321    /// # });
322    /// ```
323    fn with_if<T>(self, enable: bool, middleware: T) -> EitherEndpoint<Self, T::Output>
324    where
325        T: Middleware<Self::Endpoint>,
326        Self: Sized,
327    {
328        if !enable {
329            EitherEndpoint::A(self)
330        } else {
331            EitherEndpoint::B(middleware.transform(self.into_endpoint()))
332        }
333    }
334
335    /// Attach a state data to the endpoint, similar to `with(AddData(T))`.
336    ///
337    /// # Example
338    ///
339    /// ```
340    /// use poem::{
341    ///     Endpoint, EndpointExt, Request, handler, http::StatusCode, test::TestClient, web::Data,
342    /// };
343    ///
344    /// #[handler]
345    /// async fn index(data: Data<&i32>) -> String {
346    ///     format!("{}", data.0)
347    /// }
348    ///
349    /// # tokio::runtime::Runtime::new().unwrap().block_on(async {
350    /// let resp = TestClient::new(index.data(100i32)).get("/").send().await;
351    /// resp.assert_status_is_ok();
352    /// resp.assert_text("100").await;
353    /// # });
354    /// ```
355    fn data<T>(self, data: T) -> AddDataEndpoint<Self::Endpoint, T>
356    where
357        T: Clone + Send + Sync + 'static,
358        Self: Sized,
359    {
360        self.with(AddData::new(data))
361    }
362
363    /// if `data` is `Some(T)` then attach the value to the endpoint.
364    fn data_opt<T>(
365        self,
366        data: Option<T>,
367    ) -> EitherEndpoint<AddDataEndpoint<Self::Endpoint, T>, Self>
368    where
369        T: Clone + Send + Sync + 'static,
370        Self: Sized,
371    {
372        match data {
373            Some(data) => EitherEndpoint::A(AddData::new(data).transform(self.into_endpoint())),
374            None => EitherEndpoint::B(self),
375        }
376    }
377
378    /// Maps the request of this endpoint.
379    ///
380    /// # Example
381    ///
382    /// ```
383    /// use poem::{
384    ///     Endpoint, EndpointExt, Error, Request, Result, handler, http::StatusCode, test::TestClient,
385    /// };
386    ///
387    /// #[handler]
388    /// async fn index(data: String) -> String {
389    ///     data
390    /// }
391    ///
392    /// # tokio::runtime::Runtime::new().unwrap().block_on(async {
393    /// let mut resp = index
394    ///     .before(|mut req| async move {
395    ///         req.set_body("abc");
396    ///         Ok(req)
397    ///     })
398    ///     .call(Request::default())
399    ///     .await
400    ///     .unwrap();
401    /// assert_eq!(resp.take_body().into_string().await.unwrap(), "abc");
402    /// # });
403    /// ```
404    fn before<F, Fut>(self, f: F) -> Before<Self, F>
405    where
406        F: Fn(Request) -> Fut + Send + Sync,
407        Fut: Future<Output = Result<Request>> + Send,
408        Self: Sized,
409    {
410        Before::new(self, f)
411    }
412
413    /// Maps the output of this endpoint.
414    ///
415    /// # Example
416    ///
417    /// ```
418    /// use poem::{Endpoint, EndpointExt, Error, Request, Result, handler, http::StatusCode};
419    ///
420    /// #[handler]
421    /// async fn index() -> &'static str {
422    ///     "abc"
423    /// }
424    ///
425    /// # tokio::runtime::Runtime::new().unwrap().block_on(async {
426    /// let mut resp = index
427    ///     .after(|res| async move {
428    ///         match res {
429    ///             Ok(resp) => Ok(resp.into_body().into_string().await.unwrap() + "def"),
430    ///             Err(err) => Err(err),
431    ///         }
432    ///     })
433    ///     .call(Request::default())
434    ///     .await
435    ///     .unwrap();
436    /// assert_eq!(resp, "abcdef");
437    /// # });
438    /// ```
439    fn after<F, Fut, T>(self, f: F) -> After<Self::Endpoint, F>
440    where
441        F: Fn(Result<<Self::Endpoint as Endpoint>::Output>) -> Fut + Send + Sync,
442        Fut: Future<Output = Result<T>> + Send,
443        T: IntoResponse,
444        Self: Sized,
445    {
446        After::new(self.into_endpoint(), f)
447    }
448
449    /// Maps the request and response of this endpoint.
450    ///
451    /// # Example
452    ///
453    /// ```
454    /// use poem::{
455    ///     Endpoint, EndpointExt, Error, Request, Result, handler,
456    ///     http::{HeaderMap, HeaderValue, StatusCode},
457    /// };
458    ///
459    /// #[handler]
460    /// async fn index(headers: &HeaderMap) -> String {
461    ///     headers
462    ///         .get("x-value")
463    ///         .and_then(|value| value.to_str().ok())
464    ///         .unwrap()
465    ///         .to_string()
466    ///         + ","
467    /// }
468    ///
469    /// # tokio::runtime::Runtime::new().unwrap().block_on(async {
470    /// let mut resp = index
471    ///     .around(|ep, mut req| async move {
472    ///         req.headers_mut()
473    ///             .insert("x-value", HeaderValue::from_static("hello"));
474    ///         let mut resp = ep.call(req).await?;
475    ///         Ok(resp.take_body().into_string().await.unwrap() + "world")
476    ///     })
477    ///     .call(Request::default())
478    ///     .await
479    ///     .unwrap();
480    /// assert_eq!(resp, "hello,world");
481    /// # });
482    /// ```
483    fn around<F, Fut, R>(self, f: F) -> Around<Self::Endpoint, F>
484    where
485        F: Fn(Arc<Self::Endpoint>, Request) -> Fut + Send + Sync + 'static,
486        Fut: Future<Output = Result<R>> + Send + 'static,
487        R: IntoResponse,
488        Self: Sized,
489    {
490        Around::new(self.into_endpoint(), f)
491    }
492
493    /// Convert the output of this endpoint into a response.
494    /// [`Response`](crate::Response).
495    ///
496    /// # Example
497    ///
498    /// ```
499    /// use poem::{
500    ///     Endpoint, EndpointExt, Error, Request, Response, Result, endpoint::make, http::StatusCode,
501    /// };
502    ///
503    /// let ep1 = make(|_| async { "hello" }).map_to_response();
504    /// let ep2 = make(|_| async { Err::<(), Error>(Error::from_status(StatusCode::BAD_REQUEST)) })
505    ///     .map_to_response();
506    ///
507    /// # tokio::runtime::Runtime::new().unwrap().block_on(async {
508    /// let resp = ep1.call(Request::default()).await.unwrap();
509    /// assert_eq!(resp.into_body().into_string().await.unwrap(), "hello");
510    ///
511    /// let err = ep2.call(Request::default()).await.unwrap_err();
512    /// assert_eq!(err.into_response().status(), StatusCode::BAD_REQUEST);
513    /// # });
514    /// ```
515    fn map_to_response(self) -> MapToResponse<Self::Endpoint>
516    where
517        Self: Sized,
518    {
519        MapToResponse::new(self.into_endpoint())
520    }
521
522    /// Convert the output of this endpoint into a response.
523    /// [`Response`](crate::Response).
524    ///
525    /// NOTE: Unlike [`EndpointExt::map_to_response`], when an error occurs, it
526    /// will also convert the error into a response object, so this endpoint
527    /// will just returns `Ok(Response)`.
528    ///
529    /// # Example
530    ///
531    /// ```
532    /// use poem::{
533    ///     Endpoint, EndpointExt, Error, Request, Response, Result, endpoint::make, http::StatusCode,
534    /// };
535    ///
536    /// let ep1 = make(|_| async { "hello" }).to_response();
537    /// let ep2 = make(|_| async { Err::<(), Error>(Error::from_status(StatusCode::BAD_REQUEST)) })
538    ///     .to_response();
539    ///
540    /// # tokio::runtime::Runtime::new().unwrap().block_on(async {
541    /// let resp = ep1.call(Request::default()).await.unwrap();
542    /// assert_eq!(resp.into_body().into_string().await.unwrap(), "hello");
543    ///
544    /// let resp = ep2.call(Request::default()).await.unwrap();
545    /// assert_eq!(resp.status(), StatusCode::BAD_REQUEST);
546    /// # });
547    /// ```
548    fn to_response(self) -> ToResponse<Self::Endpoint>
549    where
550        Self: Sized,
551    {
552        ToResponse::new(self.into_endpoint())
553    }
554
555    /// Maps the response of this endpoint.
556    ///
557    /// # Example
558    ///
559    /// ```
560    /// use poem::{
561    ///     Endpoint, EndpointExt, Error, Request, Response, Result, endpoint::make, http::StatusCode,
562    /// };
563    ///
564    /// let ep = make(|_| async { "hello" }).map(|value| async move { format!("{}, world!", value) });
565    ///
566    /// # tokio::runtime::Runtime::new().unwrap().block_on(async {
567    /// let mut resp: String = ep.call(Request::default()).await.unwrap();
568    /// assert_eq!(resp, "hello, world!");
569    /// # });
570    /// ```
571    fn map<F, Fut, R, R2>(self, f: F) -> Map<Self::Endpoint, F>
572    where
573        F: Fn(R) -> Fut + Send + Sync,
574        Fut: Future<Output = R2> + Send,
575        R: IntoResponse,
576        R2: IntoResponse,
577        Self: Sized,
578        Self::Endpoint: Endpoint<Output = R> + Sized,
579    {
580        Map::new(self.into_endpoint(), f)
581    }
582
583    /// Calls `f` if the result is `Ok`, otherwise returns the `Err` value of
584    /// self.
585    ///
586    /// # Example
587    ///
588    /// ```
589    /// use poem::{
590    ///     Endpoint, EndpointExt, Error, Request, Response, Result, endpoint::make, http::StatusCode,
591    /// };
592    ///
593    /// let ep1 = make(|_| async { "hello" })
594    ///     .and_then(|value| async move { Ok(format!("{}, world!", value)) });
595    /// let ep2 = make(|_| async { Err::<String, _>(Error::from_status(StatusCode::BAD_REQUEST)) })
596    ///     .and_then(|value| async move { Ok(format!("{}, world!", value)) });
597    ///
598    /// # tokio::runtime::Runtime::new().unwrap().block_on(async {
599    /// let resp: String = ep1.call(Request::default()).await.unwrap();
600    /// assert_eq!(resp, "hello, world!");
601    ///
602    /// let err: Error = ep2.call(Request::default()).await.unwrap_err();
603    /// assert_eq!(err.into_response().status(), StatusCode::BAD_REQUEST);
604    /// # });
605    /// ```
606    fn and_then<F, Fut, R, R2>(self, f: F) -> AndThen<Self::Endpoint, F>
607    where
608        F: Fn(R) -> Fut + Send + Sync,
609        Fut: Future<Output = Result<R2>> + Send,
610        R: IntoResponse,
611        R2: IntoResponse,
612        Self: Sized,
613        Self::Endpoint: Endpoint<Output = R> + Sized,
614    {
615        AndThen::new(self.into_endpoint(), f)
616    }
617
618    /// Catch all errors and convert it into a response.
619    ///
620    /// # Example
621    ///
622    /// ```
623    /// use http::Uri;
624    /// use poem::{
625    ///     Endpoint, EndpointExt, Error, IntoResponse, Request, Response, Route, handler,
626    ///     http::StatusCode, web::Json,
627    /// };
628    /// use serde::Serialize;
629    ///
630    /// #[handler]
631    /// async fn index() {}
632    ///
633    /// let app = Route::new()
634    ///     .at("/index", index)
635    ///     .catch_all_error(custom_error);
636    ///
637    /// #[derive(Serialize)]
638    /// struct ErrorResponse {
639    ///     message: String,
640    /// }
641    ///
642    /// async fn custom_error(err: Error) -> impl IntoResponse {
643    ///     Json(ErrorResponse {
644    ///         message: err.to_string(),
645    ///     })
646    /// }
647    ///
648    /// # tokio::runtime::Runtime::new().unwrap().block_on(async {
649    /// let resp = app
650    ///     .call(Request::builder().uri(Uri::from_static("/abc")).finish())
651    ///     .await
652    ///     .unwrap();
653    /// assert_eq!(resp.status(), StatusCode::OK);
654    /// assert_eq!(
655    ///     resp.into_body().into_string().await.unwrap(),
656    ///     "{\"message\":\"not found\"}"
657    /// );
658    /// # })
659    /// ```
660    fn catch_all_error<F, Fut, R>(self, f: F) -> CatchAllError<Self, F, R>
661    where
662        F: Fn(Error) -> Fut + Send + Sync,
663        Fut: Future<Output = R> + Send,
664        R: IntoResponse + Send,
665        Self: Sized + Sync,
666    {
667        CatchAllError::new(self, f)
668    }
669
670    /// Catch the specified type of error and convert it into a response.
671    ///
672    /// # Example
673    ///
674    /// ```
675    /// use http::Uri;
676    /// use poem::{
677    ///     Endpoint, EndpointExt, IntoResponse, Request, Response, Route, error::NotFoundError,
678    ///     handler, http::StatusCode,
679    /// };
680    ///
681    /// #[handler]
682    /// async fn index() {}
683    ///
684    /// let app = Route::new().at("/index", index).catch_error(custom_404);
685    ///
686    /// async fn custom_404(_: NotFoundError) -> impl IntoResponse {
687    ///     "custom not found".with_status(StatusCode::NOT_FOUND)
688    /// }
689    ///
690    /// # tokio::runtime::Runtime::new().unwrap().block_on(async {
691    ///
692    /// let resp = app
693    ///     .call(Request::builder().uri(Uri::from_static("/abc")).finish())
694    ///     .await
695    ///     .unwrap();
696    /// assert_eq!(resp.status(), StatusCode::NOT_FOUND);
697    /// assert_eq!(
698    ///     resp.into_body().into_string().await.unwrap(),
699    ///     "custom not found"
700    /// );
701    /// # })
702    /// ```
703    fn catch_error<F, Fut, R, ErrType>(self, f: F) -> CatchError<Self, F, R, ErrType>
704    where
705        F: Fn(ErrType) -> Fut + Send + Sync,
706        Fut: Future<Output = R> + Send,
707        R: IntoResponse + Send + Sync,
708        ErrType: std::error::Error + Send + Sync + 'static,
709        Self: Sized,
710    {
711        CatchError::new(self, f)
712    }
713
714    /// Does something with each error.
715    ///
716    /// # Example
717    ///
718    /// ```
719    /// use poem::{EndpointExt, Route, handler};
720    ///
721    /// #[handler]
722    /// fn index() {}
723    ///
724    /// let app = Route::new().at("/", index).inspect_all_err(|err| {
725    ///     println!("error: {}", err);
726    /// });
727    /// ```
728    fn inspect_all_err<F>(self, f: F) -> InspectAllError<Self, F>
729    where
730        F: Fn(&Error) + Send + Sync,
731        Self: Sized,
732    {
733        InspectAllError::new(self, f)
734    }
735
736    /// Does something with each specified error type.
737    ///
738    /// # Example
739    ///
740    /// ```
741    /// use poem::{EndpointExt, Route, error::NotFoundError, handler};
742    ///
743    /// #[handler]
744    /// fn index() {}
745    ///
746    /// let app = Route::new()
747    ///     .at("/", index)
748    ///     .inspect_err(|err: &NotFoundError| {
749    ///         println!("error: {}", err);
750    ///     });
751    /// ```
752    fn inspect_err<F, ErrType>(self, f: F) -> InspectError<Self, F, ErrType>
753    where
754        F: Fn(&ErrType) + Send + Sync,
755        ErrType: std::error::Error + Send + Sync + 'static,
756        Self: Sized,
757    {
758        InspectError::new(self, f)
759    }
760}
761
762impl<T: IntoEndpoint> EndpointExt for T {}
763
764/// Represents a type that can convert into endpoint.
765pub trait IntoEndpoint {
766    /// Represents the endpoint type.
767    type Endpoint: Endpoint;
768
769    /// Converts this object into endpoint.
770    fn into_endpoint(self) -> Self::Endpoint;
771}
772
773impl<T: Endpoint> IntoEndpoint for T {
774    type Endpoint = T;
775
776    fn into_endpoint(self) -> Self::Endpoint {
777        self
778    }
779}
780
#[cfg(test)]
mod test {
    use http::{HeaderValue, Uri};

    use crate::{
        Endpoint, EndpointExt, Error, IntoEndpoint, Request, Route,
        endpoint::{make, make_sync},
        get, handler,
        http::{Method, StatusCode},
        middleware::SetHeader,
        test::TestClient,
        web::Data,
    };

    #[tokio::test]
    async fn test_make() {
        let ep = make(|req| async move { format!("method={}", req.method()) }).map_to_response();
        let mut resp = ep
            .call(Request::builder().method(Method::DELETE).finish())
            .await
            .unwrap();
        assert_eq!(
            resp.take_body().into_string().await.unwrap(),
            "method=DELETE"
        );
    }

    #[tokio::test]
    async fn test_before() {
        assert_eq!(
            make_sync(|req| req.method().to_string())
                .before(|mut req| async move {
                    req.set_method(Method::POST);
                    Ok(req)
                })
                .call(Request::default())
                .await
                .unwrap(),
            "POST"
        );
    }

    #[tokio::test]
    async fn test_after() {
        assert_eq!(
            make_sync(|_| "abc")
                .after(|_| async { Ok::<_, Error>("def") })
                .call(Request::default())
                .await
                .unwrap(),
            "def"
        );
    }

    #[tokio::test]
    async fn test_map_to_response() {
        assert_eq!(
            make_sync(|_| "abc")
                .map_to_response()
                .call(Request::default())
                .await
                .unwrap()
                .take_body()
                .into_string()
                .await
                .unwrap(),
            "abc"
        );
    }

    #[tokio::test]
    async fn test_and_then() {
        assert_eq!(
            make_sync(|_| "abc")
                .and_then(|resp| async move { Ok(resp.to_string() + "def") })
                .call(Request::default())
                .await
                .unwrap(),
            "abcdef"
        );

        // An error from the inner endpoint must reach the caller unchanged.
        let resp = make_sync(|_| Err::<String, _>(Error::from_status(StatusCode::BAD_REQUEST)))
            .and_then(|resp| async move { Ok(resp + "def") })
            .get_response(Request::default())
            .await;
        assert_eq!(resp.status(), StatusCode::BAD_REQUEST);
    }

    #[tokio::test]
    async fn test_map() {
        assert_eq!(
            make_sync(|_| "abc")
                .map(|resp| async move { resp.to_string() + "def" })
                .call(Request::default())
                .await
                .unwrap(),
            "abcdef"
        );

        // An error from the inner endpoint must reach the caller unchanged.
        let resp = make_sync(|_| Err::<String, _>(Error::from_status(StatusCode::BAD_REQUEST)))
            .map(|resp| async move { resp.to_string() + "def" })
            .get_response(Request::default())
            .await;
        assert_eq!(resp.status(), StatusCode::BAD_REQUEST);
    }

    #[tokio::test]
    async fn test_around() {
        let ep = make(|req| async move { req.into_body().into_string().await.unwrap() + "b" });
        assert_eq!(
            ep.around(|ep, mut req| async move {
                req.set_body("a");
                let resp = ep.call(req).await?;
                Ok(resp + "c")
            })
            .call(Request::default())
            .await
            .unwrap(),
            "abc"
        );
    }

    #[tokio::test]
    async fn test_with_if() {
        let resp = make_sync(|_| ())
            .with_if(true, SetHeader::new().appending("a", 1))
            .call(Request::default())
            .await
            .unwrap();
        assert_eq!(
            resp.headers().get("a"),
            Some(&HeaderValue::from_static("1"))
        );

        let resp = make_sync(|_| ())
            .with_if(false, SetHeader::new().appending("a", 1))
            .call(Request::default())
            .await
            .unwrap();
        assert_eq!(resp.headers().get("a"), None);
    }

    #[tokio::test]
    async fn test_into_endpoint() {
        struct MyEndpointFactory;

        impl IntoEndpoint for MyEndpointFactory {
            type Endpoint = Route;

            fn into_endpoint(self) -> Self::Endpoint {
                Route::new()
                    .at("/a", get(make_sync(|_| "a")))
                    .at("/b", get(make_sync(|_| "b")))
            }
        }

        let app = Route::new().nest("/api", MyEndpointFactory);

        assert_eq!(
            app.call(Request::builder().uri(Uri::from_static("/api/a")).finish())
                .await
                .unwrap()
                .take_body()
                .into_string()
                .await
                .unwrap(),
            "a"
        );

        assert_eq!(
            app.call(Request::builder().uri(Uri::from_static("/api/b")).finish())
                .await
                .unwrap()
                .take_body()
                .into_string()
                .await
                .unwrap(),
            "b"
        );
    }

    #[tokio::test]
    async fn test_data_opt() {
        #[handler(internal)]
        async fn index(data: Option<Data<&i32>>) -> String {
            match data.as_deref() {
                Some(value) => format!("{value}"),
                None => "none".to_string(),
            }
        }

        let cli = TestClient::new(index.data_opt(Some(100)));
        let resp = cli.get("/").send().await;
        resp.assert_status_is_ok();
        resp.assert_text("100").await;

        let cli = TestClient::new(index.data_opt(None::<i32>));
        let resp = cli.get("/").send().await;
        resp.assert_status_is_ok();
        resp.assert_text("none").await;
    }

    // `boxed` erases the endpoint type; the resulting `BoxEndpoint` must still
    // be callable as an `Endpoint` (via the `Box<T>` and `dyn DynEndpoint`
    // impls) and yield the original output.
    #[tokio::test]
    async fn test_boxed() {
        let ep = make_sync(|_| "abc").boxed();
        assert_eq!(ep.call(Request::default()).await.unwrap(), "abc");
    }
}