Skip to main content

warp/
test.rs

1//! Test utilities to test your filters.
2//!
3//! [`Filter`](../trait.Filter.html)s can be easily tested without starting up an HTTP
4//! server, by making use of the [`RequestBuilder`](./struct.RequestBuilder.html) in this
5//! module.
6//!
7//! # Testing Filters
8//!
9//! It's easy to test filters, especially if smaller filters are used to build
10//! up your full set. Consider these example filters:
11//!
12//! ```
13//! use warp::Filter;
14//!
15//! fn sum() -> impl Filter<Extract = (u32,), Error = warp::Rejection> + Copy {
16//!     warp::path::param()
17//!         .and(warp::path::param())
18//!         .map(|x: u32, y: u32| {
19//!             x + y
20//!         })
21//! }
22//!
23//! fn math() -> impl Filter<Extract = (String,), Error = warp::Rejection> + Copy {
24//!     warp::post()
25//!         .and(sum())
26//!         .map(|z: u32| {
27//!             format!("Sum = {}", z)
28//!         })
29//! }
30//! ```
31//!
32//! We can test some requests against the `sum` filter like this:
33//!
34//! ```
35//! # use warp::Filter;
36//! #[tokio::test]
37//! async fn test_sum() {
38//! #    let sum = || warp::any().map(|| 3);
39//!     let filter = sum();
40//!
41//!     // Execute `sum` and get the `Extract` back.
42//!     let value = warp::test::request()
43//!         .path("/1/2")
44//!         .filter(&filter)
45//!         .await
46//!         .unwrap();
47//!     assert_eq!(value, 3);
48//!
//!     // Or simply test if a request matches (doesn't reject).
//!     // `-5` can't be parsed as a `u32`, so this request is rejected.
//!     assert!(
//!         !warp::test::request()
//!             .path("/1/-5")
//!             .matches(&filter)
//!             .await
//!     );
56//! }
57//! ```
58//!
59//! If the filter returns something that implements `Reply`, and thus can be
60//! turned into a response sent back to the client, we can test what exact
61//! response is returned. The `math` filter uses the `sum` filter, but returns
62//! a `String` that can be turned into a response.
63//!
64//! ```
65//! # use warp::Filter;
//! #[tokio::test]
//! async fn test_math() {
//! #    let math = || warp::any().map(warp::reply);
//!     let filter = math();
//!
//!     let res = warp::test::request()
//!         .path("/1/2")
//!         .reply(&filter)
//!         .await;
//!     assert_eq!(res.status(), 405, "GET is not allowed");
//!
//!     let res = warp::test::request()
//!         .method("POST")
//!         .path("/1/2")
//!         .reply(&filter)
//!         .await;
//!     assert_eq!(res.status(), 200);
//!     assert_eq!(res.body(), "Sum = 3");
//! }
83//! ```
84use std::convert::TryFrom;
85use std::error::Error as StdError;
86use std::fmt;
87use std::future::Future;
88use std::net::SocketAddr;
89#[cfg(feature = "websocket")]
90use std::pin::Pin;
91#[cfg(feature = "websocket")]
92use std::task::Context;
93#[cfg(feature = "websocket")]
94use std::task::Poll;
95
96use bytes::Bytes;
97#[cfg(feature = "websocket")]
98use futures_channel::mpsc;
99#[cfg(feature = "websocket")]
100use futures_util::StreamExt;
101use futures_util::{future, FutureExt, TryFutureExt};
102use http::{
103    header::{HeaderName, HeaderValue},
104    Response,
105};
106use http_body_util::BodyExt;
107use serde::Serialize;
108#[cfg(feature = "websocket")]
109use tokio::sync::oneshot;
110
111use crate::filter::Filter;
112use crate::filters::addr::RemoteAddr;
113#[cfg(feature = "websocket")]
114use crate::filters::ws::Message;
115use crate::reject::IsReject;
116use crate::reply::Reply;
117use crate::route::{self, Route};
118use crate::Request;
119#[cfg(feature = "websocket")]
120use crate::{Sink, Stream};
121
122use self::inner::OneOrTuple;
123
124/// Starts a new test `RequestBuilder`.
125pub fn request() -> RequestBuilder {
126    RequestBuilder {
127        req: Request::default(),
128    }
129}
130
/// Starts a new test `WsBuilder`.
///
/// The handshake request starts from the same defaults as [`request`].
#[cfg(feature = "websocket")]
pub fn ws() -> WsBuilder {
    let req = request();
    WsBuilder { req }
}
136
/// A request builder for testing filters.
///
/// Created by [`request`]. Configure the request with methods such as `method`,
/// `path`, `header` and `body`, then run it against a filter with `filter`,
/// `matches` or `reply`.
///
/// See [module documentation](crate::test) for an overview.
#[must_use = "RequestBuilder does nothing on its own"]
#[derive(Debug)]
pub struct RequestBuilder {
    // The request under construction; handed to the filter when it is applied.
    req: Request,
}
145
/// A Websocket builder for testing filters.
///
/// Created by [`ws`]. Configure the handshake request with `path` and `header`,
/// then call `handshake` to connect it to a filter.
///
/// See [module documentation](crate::test) for an overview.
#[cfg(feature = "websocket")]
#[must_use = "WsBuilder does nothing on its own"]
#[derive(Debug)]
pub struct WsBuilder {
    // The HTTP request that is sent as the websocket upgrade request.
    req: RequestBuilder,
}
155
/// A test client for Websocket filters.
///
/// Obtained from `WsBuilder::handshake`. Messages can be exchanged with
/// `send`/`send_text`/`recv`, or through its `Sink` and `Stream` implementations.
#[cfg(feature = "websocket")]
pub struct WsClient {
    // Outgoing messages; a background task forwards them to the server.
    tx: mpsc::UnboundedSender<crate::ws::Message>,
    // Incoming messages; a background task fills this from the server connection.
    rx: mpsc::UnboundedReceiver<Result<crate::ws::Message, crate::error::Error>>,
}
162
/// An error from Websocket filter tests.
///
/// Returned by `WsBuilder::handshake`, `WsClient::recv`, `WsClient::recv_closed`
/// and the `Sink`/`Stream` implementations of `WsClient`. Its `Display` output is
/// `websocket error: ` followed by the underlying cause.
#[derive(Debug)]
pub struct WsError {
    // The underlying error, or a plain message such as `"closed"`.
    cause: Box<dyn StdError + Send + Sync>,
}
168
169impl RequestBuilder {
170    /// Sets the method of this builder.
171    ///
172    /// The default if not set is `GET`.
173    ///
174    /// # Example
175    ///
176    /// ```
177    /// let req = warp::test::request()
178    ///     .method("POST");
179    /// ```
180    ///
181    /// # Panic
182    ///
183    /// This panics if the passed string is not able to be parsed as a valid
184    /// `Method`.
185    pub fn method(mut self, method: &str) -> Self {
186        *self.req.method_mut() = method.parse().expect("valid method");
187        self
188    }
189
190    /// Sets the request path of this builder.
191    ///
192    /// The default is not set is `/`.
193    ///
194    /// # Example
195    ///
196    /// ```
197    /// let req = warp::test::request()
198    ///     .path("/todos/33");
199    /// ```
200    ///
201    /// # Panic
202    ///
203    /// This panics if the passed string is not able to be parsed as a valid
204    /// `Uri`.
205    pub fn path(mut self, p: &str) -> Self {
206        let uri = p.parse().expect("test request path invalid");
207        *self.req.uri_mut() = uri;
208        self
209    }
210
211    /// Set a header for this request.
212    ///
213    /// # Example
214    ///
215    /// ```
216    /// let req = warp::test::request()
217    ///     .header("accept", "application/json");
218    /// ```
219    ///
220    /// # Panic
221    ///
222    /// This panics if the passed strings are not able to be parsed as a valid
223    /// `HeaderName` and `HeaderValue`.
224    pub fn header<K, V>(mut self, key: K, value: V) -> Self
225    where
226        HeaderName: TryFrom<K>,
227        HeaderValue: TryFrom<V>,
228    {
229        let name: HeaderName = TryFrom::try_from(key)
230            .map_err(|_| ())
231            .expect("invalid header name");
232        let value = TryFrom::try_from(value)
233            .map_err(|_| ())
234            .expect("invalid header value");
235        self.req.headers_mut().insert(name, value);
236        self
237    }
238
239    /// Set the remote address of this request
240    ///
241    /// Default is no remote address.
242    ///
243    /// # Example
244    /// ```
245    /// use std::net::{IpAddr, Ipv4Addr, SocketAddr};
246    ///
247    /// let req = warp::test::request()
248    ///     .remote_addr(SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 8080));
249    /// ```
250    pub fn remote_addr(mut self, addr: SocketAddr) -> Self {
251        self.req.extensions_mut().insert(RemoteAddr(addr));
252        self
253    }
254
255    /// Add a type to the request's `http::Extensions`.
256    pub fn extension<T>(mut self, ext: T) -> Self
257    where
258        T: Clone + Send + Sync + 'static,
259    {
260        self.req.extensions_mut().insert(ext);
261        self
262    }
263
264    /// Set the bytes of this request body.
265    ///
266    /// Default is an empty body.
267    ///
268    /// # Example
269    ///
270    /// ```
271    /// let req = warp::test::request()
272    ///     .body("foo=bar&baz=quux");
273    /// ```
274    pub fn body(mut self, body: impl AsRef<[u8]>) -> Self {
275        let body = body.as_ref().to_vec();
276        let len = body.len();
277        *self.req.body_mut() = body.into();
278        self.header("content-length", len.to_string())
279    }
280
281    /// Set the bytes of this request body by serializing a value into JSON.
282    ///
283    /// # Example
284    ///
285    /// ```
286    /// let req = warp::test::request()
287    ///     .json(&true);
288    /// ```
289    pub fn json(mut self, val: &impl Serialize) -> Self {
290        let vec = serde_json::to_vec(val).expect("json() must serialize to JSON");
291        let len = vec.len();
292        *self.req.body_mut() = vec.into();
293        self.header("content-length", len.to_string())
294            .header("content-type", "application/json")
295    }
296
297    /// Tries to apply the `Filter` on this request.
298    ///
299    /// # Example
300    ///
301    /// ```no_run
302    /// async {
303    ///     let param = warp::path::param::<u32>();
304    ///
305    ///     let ex = warp::test::request()
306    ///         .path("/41")
307    ///         .filter(&param)
308    ///         .await
309    ///         .unwrap();
310    ///
311    ///     assert_eq!(ex, 41);
312    ///
313    ///     assert!(
314    ///         warp::test::request()
315    ///             .path("/foo")
316    ///             .filter(&param)
317    ///             .await
318    ///             .is_err()
319    ///     );
320    ///};
321    /// ```
322    pub async fn filter<F>(self, f: &F) -> Result<<F::Extract as OneOrTuple>::Output, F::Error>
323    where
324        F: Filter,
325        F::Future: Send + 'static,
326        F::Extract: OneOrTuple + Send + 'static,
327        F::Error: Send + 'static,
328    {
329        self.apply_filter(f).await.map(|ex| ex.one_or_tuple())
330    }
331
332    /// Returns whether the `Filter` matches this request, or rejects it.
333    ///
334    /// # Example
335    ///
336    /// ```no_run
337    /// async {
338    ///     let get = warp::get();
339    ///     let post = warp::post();
340    ///
341    ///     assert!(
342    ///         warp::test::request()
343    ///             .method("GET")
344    ///             .matches(&get)
345    ///             .await
346    ///     );
347    ///
348    ///     assert!(
349    ///         !warp::test::request()
350    ///             .method("GET")
351    ///             .matches(&post)
352    ///             .await
353    ///     );
354    ///};
355    /// ```
356    pub async fn matches<F>(self, f: &F) -> bool
357    where
358        F: Filter,
359        F::Future: Send + 'static,
360        F::Extract: Send + 'static,
361        F::Error: Send + 'static,
362    {
363        self.apply_filter(f).await.is_ok()
364    }
365
366    /// Returns `Response` provided by applying the `Filter`.
367    ///
368    /// This requires that the supplied `Filter` return a [`Reply`].
369    pub async fn reply<F>(self, f: &F) -> Response<Bytes>
370    where
371        F: Filter + 'static,
372        F::Extract: Reply + Send,
373        F::Error: IsReject + Send,
374    {
375        // TODO: de-duplicate this and apply_filter()
376        assert!(!route::is_set(), "nested test filter calls");
377
378        let route = Route::new(self.req);
379        let mut fut = Box::pin(
380            route::set(&route, move || f.filter(crate::filter::Internal)).then(|result| {
381                let res = match result {
382                    Ok(rep) => rep.into_response(),
383                    Err(rej) => {
384                        tracing::debug!("rejected: {:?}", rej);
385                        rej.into_response()
386                    }
387                };
388                let (parts, body) = res.into_parts();
389                {
390                    body.collect()
391                        .map_ok(|chunk| Response::from_parts(parts, chunk.to_bytes()))
392                }
393            }),
394        );
395
396        let fut = future::poll_fn(move |cx| route::set(&route, || fut.as_mut().poll(cx)));
397
398        fut.await.expect("reply shouldn't fail")
399    }
400
401    fn apply_filter<F>(self, f: &F) -> impl Future<Output = Result<F::Extract, F::Error>>
402    where
403        F: Filter,
404        F::Future: Send + 'static,
405        F::Extract: Send + 'static,
406        F::Error: Send + 'static,
407    {
408        assert!(!route::is_set(), "nested test filter calls");
409
410        let route = Route::new(self.req);
411        let mut fut = Box::pin(route::set(&route, move || {
412            f.filter(crate::filter::Internal)
413        }));
414        future::poll_fn(move |cx| route::set(&route, || fut.as_mut().poll(cx)))
415    }
416}
417
#[cfg(feature = "websocket")]
impl WsBuilder {
    /// Sets the request path of this builder.
    ///
    /// The default if not set is `/`.
    ///
    /// # Example
    ///
    /// ```
    /// let req = warp::test::ws()
    ///     .path("/chat");
    /// ```
    ///
    /// # Panics
    ///
    /// This panics if the passed string is not able to be parsed as a valid
    /// `Uri`.
    pub fn path(self, p: &str) -> Self {
        WsBuilder {
            req: self.req.path(p),
        }
    }

    /// Set a header for this request.
    ///
    /// # Example
    ///
    /// ```
    /// let req = warp::test::ws()
    ///     .header("foo", "bar");
    /// ```
    ///
    /// # Panics
    ///
    /// This panics if the passed strings are not able to be parsed as a valid
    /// `HeaderName` and `HeaderValue`.
    pub fn header<K, V>(self, key: K, value: V) -> Self
    where
        HeaderName: TryFrom<K>,
        HeaderValue: TryFrom<V>,
    {
        WsBuilder {
            req: self.req.header(key, value),
        }
    }

    /// Execute this Websocket request against the provided filter.
    ///
    /// The filter is served on a freshly bound local port (`127.0.0.1:0`) and a
    /// real HTTP/1 upgrade handshake is performed against it. If the handshake
    /// succeeds, returns a `WsClient`. This spawns Tokio tasks, so it must be
    /// called from within a Tokio runtime.
    ///
    /// # Example
    ///
    /// ```no_run
    /// use futures_util::future;
    /// use warp::Filter;
    /// #[tokio::main]
    /// # async fn main() {
    ///
    /// // Some route that accepts websockets (but drops them immediately).
    /// let route = warp::ws()
    ///     .map(|ws: warp::ws::Ws| {
    ///         ws.on_upgrade(|_| future::ready(()))
    ///     });
    ///
    /// let client = warp::test::ws()
    ///     .handshake(route)
    ///     .await
    ///     .expect("handshake");
    /// # }
    /// ```
    ///
    /// # Errors
    ///
    /// Returns a `WsError` if connecting to the local server, sending the upgrade
    /// request, or upgrading the connection fails.
    ///
    /// # Panics
    ///
    /// Panics if the background task performing the handshake panics, for example
    /// because the local listener could not be bound.
    pub async fn handshake<F>(self, f: F) -> Result<WsClient, WsError>
    where
        F: Filter + Clone + Send + Sync + 'static,
        F::Extract: Reply + Send,
        F::Error: IsReject + Send,
    {
        let (upgraded_tx, upgraded_rx) = oneshot::channel();
        let (wr_tx, wr_rx) = mpsc::unbounded();
        let (rd_tx, rd_rx) = mpsc::unbounded();

        tokio::spawn(async move {
            use tokio_tungstenite::tungstenite::protocol;

            // Serve the filter on an ephemeral local port so a real upgrade can be
            // performed against it.
            let listener = tokio::net::TcpListener::bind(SocketAddr::from(([127, 0, 0, 1], 0)))
                .await
                .expect("binding");
            let addr = listener.local_addr().unwrap();
            tokio::spawn(async move {
                crate::serve(f).incoming(listener).run().await;
            });

            // A fixed key is used; it is the sample nonce from RFC 6455.
            let mut req = self
                .req
                .header("connection", "upgrade")
                .header("upgrade", "websocket")
                .header("sec-websocket-version", "13")
                .header("sec-websocket-key", "dGhlIHNhbXBsZSBub25jZQ==")
                .req;

            // Point the request at the local server, keeping the path and query.
            let query_string = match req.uri().query() {
                Some(q) => format!("?{}", q),
                None => String::from(""),
            };

            let uri = format!("http://{}{}{}", addr, req.uri().path(), query_string)
                .parse()
                .expect("addr + path is valid URI");

            *req.uri_mut() = uri;

            let upgrade = async move {
                let io = tokio::net::TcpStream::connect(addr).await?;
                let io = hyper_util::rt::TokioIo::new(io);
                let (mut tx, conn) = hyper::client::conn::http1::handshake(io).await?;
                // Drive the client connection (with upgrade support) in the background.
                tokio::spawn(async move {
                    let _ = conn.with_upgrades().await;
                });
                let res = tx.send_request(req).await?;
                hyper::upgrade::on(res)
                    .await
                    .map_err(|e| Box::new(e) as Box<dyn StdError + Send + Sync>)
            };

            let upgraded = match upgrade.await {
                Ok(up) => {
                    let _ = upgraded_tx.send(Ok(()));
                    up
                }
                Err(err) => {
                    let _ = upgraded_tx.send(Err(err));
                    return;
                }
            };
            let ws = crate::ws::WebSocket::from_raw_socket(
                upgraded,
                protocol::Role::Client,
                Default::default(),
            )
            .await;

            let (tx, rx) = ws.split();
            // Forward messages queued by the `WsClient` to the server.
            let write = wr_rx.map(Ok).forward(tx).map(|_| ());

            // Forward server messages to the `WsClient`, stopping at the first error
            // or close frame.
            let read = rx
                .take_while(|result| match result {
                    Err(_) => future::ready(false),
                    Ok(m) => future::ready(!m.is_close()),
                })
                .for_each(move |item| {
                    // The test may already have dropped its `WsClient`; a failed send
                    // only means nobody is listening anymore, so don't panic here.
                    let _ = rd_tx.unbounded_send(item);
                    future::ready(())
                });

            future::join(write, read).await;
        });

        match upgraded_rx.await {
            Ok(Ok(())) => Ok(WsClient {
                tx: wr_tx,
                rx: rd_rx,
            }),
            Ok(Err(err)) => Err(WsError::new(err)),
            Err(_canceled) => panic!("websocket handshake thread panicked"),
        }
    }
}
584
#[cfg(feature = "websocket")]
impl WsClient {
    /// Send a "text" websocket message to the server.
    ///
    /// # Panics
    ///
    /// Panics under the same conditions as `send`.
    pub async fn send_text(&mut self, text: impl Into<String>) {
        let message = crate::ws::Message::text(text.into());
        self.send(message).await;
    }

    /// Send a websocket message to the server.
    ///
    /// The message is queued on an unbounded channel whose receiving end is
    /// drained by a background task.
    ///
    /// # Panics
    ///
    /// Panics if the message cannot be queued, which happens once the receiving
    /// end of the channel has been dropped.
    pub async fn send(&mut self, msg: crate::ws::Message) {
        let queued = self.tx.unbounded_send(msg);
        queued.unwrap();
    }

    /// Receive a websocket message from the server.
    ///
    /// # Errors
    ///
    /// Returns a `WsError` if receiving failed, or one with cause `"closed"` once
    /// no more messages will arrive.
    pub async fn recv(&mut self) -> Result<crate::filters::ws::Message, WsError> {
        match self.rx.next().await {
            Some(Ok(message)) => Ok(message),
            Some(Err(err)) => Err(WsError::new(err)),
            // The channel is exhausted: the websocket is closed.
            None => Err(WsError::new("closed")),
        }
    }

    /// Assert the server has closed the connection.
    ///
    /// # Errors
    ///
    /// Returns a `WsError` if a message or a receive error arrives instead.
    pub async fn recv_closed(&mut self) -> Result<(), WsError> {
        match self.rx.next().await {
            // Nothing left to receive: the connection closed as expected.
            None => Ok(()),
            Some(Ok(msg)) => Err(WsError::new(format!("received message: {:?}", msg))),
            Some(Err(err)) => Err(WsError::new(err)),
        }
    }

    /// Projects the pinned client onto its outgoing sender (`WsClient` is `Unpin`).
    fn pinned_tx(self: Pin<&mut Self>) -> Pin<&mut mpsc::UnboundedSender<crate::ws::Message>> {
        Pin::new(&mut self.get_mut().tx)
    }
}
629
630#[cfg(feature = "websocket")]
631impl fmt::Debug for WsClient {
632    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
633        f.debug_struct("WsClient").finish()
634    }
635}
636
#[cfg(feature = "websocket")]
impl Sink<crate::ws::Message> for WsClient {
    type Error = WsError;

    // Every method forwards to the outgoing channel and wraps channel errors in a
    // `WsError`.

    fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        let tx = self.pinned_tx();
        tx.poll_ready(cx).map_err(WsError::new)
    }

    fn start_send(self: Pin<&mut Self>, message: Message) -> Result<(), Self::Error> {
        let tx = self.pinned_tx();
        tx.start_send(message).map_err(WsError::new)
    }

    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        let tx = self.pinned_tx();
        tx.poll_flush(cx).map_err(WsError::new)
    }

    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        let tx = self.pinned_tx();
        tx.poll_close(cx).map_err(WsError::new)
    }
}
666
#[cfg(feature = "websocket")]
impl Stream for WsClient {
    type Item = Result<crate::ws::Message, WsError>;

    /// Yields messages received from the server, wrapping receive errors in
    /// `WsError`; ends when the incoming channel is exhausted.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let this = self.get_mut();
        Pin::new(&mut this.rx)
            .poll_next(cx)
            .map(|next| next.map(|result| result.map_err(WsError::new)))
    }
}
681
682// ===== impl WsError =====
683
684#[cfg(feature = "websocket")]
685impl WsError {
686    fn new<E: Into<Box<dyn StdError + Send + Sync>>>(cause: E) -> Self {
687        WsError {
688            cause: cause.into(),
689        }
690    }
691}
692
693impl fmt::Display for WsError {
694    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
695        write!(f, "websocket error: {}", self.cause)
696    }
697}
698
699impl StdError for WsError {
700    fn description(&self) -> &str {
701        "websocket error"
702    }
703}
704
mod inner {
    /// Converts the tuple extracted by a filter into the value handed back by
    /// `RequestBuilder::filter`:
    ///
    /// - `()` stays `()`;
    /// - a one-element tuple `(T,)` is unwrapped to `T`;
    /// - tuples of 2 to 16 elements are returned unchanged.
    pub trait OneOrTuple {
        /// The value returned to the caller.
        type Output;

        /// Performs the conversion.
        fn one_or_tuple(self) -> Self::Output;
    }

    impl OneOrTuple for () {
        type Output = ();
        fn one_or_tuple(self) -> Self::Output {}
    }

    impl<T> OneOrTuple for (T,) {
        type Output = T;
        fn one_or_tuple(self) -> Self::Output {
            let (value,) = self;
            value
        }
    }

    // Implements `OneOrTuple` as the identity for the tuple of all the given type
    // parameters, then recurses on the tail. Recursion stops at a single name, so
    // one-element tuples are left to the unwrapping impl above.
    macro_rules! identity_for_tuples {
        ($last:ident) => {};
        ($first:ident, $($rest:ident),+) => {
            impl<$first, $($rest),+> OneOrTuple for ($first, $($rest),+) {
                type Output = Self;
                fn one_or_tuple(self) -> Self::Output {
                    self
                }
            }

            identity_for_tuples!($($rest),+);
        };
    }

    identity_for_tuples!(
        T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15, T16
    );
}