1use std::convert::TryFrom;
85use std::error::Error as StdError;
86use std::fmt;
87use std::future::Future;
88use std::net::SocketAddr;
89#[cfg(feature = "websocket")]
90use std::pin::Pin;
91#[cfg(feature = "websocket")]
92use std::task::Context;
93#[cfg(feature = "websocket")]
94use std::task::Poll;
95
96use bytes::Bytes;
97#[cfg(feature = "websocket")]
98use futures_channel::mpsc;
99#[cfg(feature = "websocket")]
100use futures_util::StreamExt;
101use futures_util::{future, FutureExt, TryFutureExt};
102use http::{
103 header::{HeaderName, HeaderValue},
104 Response,
105};
106use http_body_util::BodyExt;
107use serde::Serialize;
108#[cfg(feature = "websocket")]
109use tokio::sync::oneshot;
110
111use crate::filter::Filter;
112use crate::filters::addr::RemoteAddr;
113#[cfg(feature = "websocket")]
114use crate::filters::ws::Message;
115use crate::reject::IsReject;
116use crate::reply::Reply;
117use crate::route::{self, Route};
118use crate::Request;
119#[cfg(feature = "websocket")]
120use crate::{Sink, Stream};
121
122use self::inner::OneOrTuple;
123
124pub fn request() -> RequestBuilder {
126 RequestBuilder {
127 req: Request::default(),
128 }
129}
130
/// Starts building a test websocket handshake.
///
/// The returned [`WsBuilder`] wraps a fresh builder obtained from
/// [`request()`].
#[cfg(feature = "websocket")]
pub fn ws() -> WsBuilder {
    let req = request();
    WsBuilder { req }
}
136
137#[must_use = "RequestBuilder does nothing on its own"]
141#[derive(Debug)]
142pub struct RequestBuilder {
143 req: Request,
144}
145
/// Builder for a websocket handshake that is run against a filter.
///
/// It is a thin wrapper around a [`RequestBuilder`]; the websocket-specific
/// headers are added when the handshake is performed.
#[cfg(feature = "websocket")]
#[derive(Debug)]
#[must_use = "WsBuilder does nothing on its own"]
pub struct WsBuilder {
    /// The plain request the handshake is built from.
    req: RequestBuilder,
}
155
/// Client half of a test websocket connection.
///
/// Both directions are plain unbounded channels; a background task created by
/// the handshake moves messages between these channels and the socket.
#[cfg(feature = "websocket")]
pub struct WsClient {
    /// Outgoing messages, forwarded to the websocket by the background task.
    tx: mpsc::UnboundedSender<crate::ws::Message>,
    /// Items read from the websocket by the background task.
    rx: mpsc::UnboundedReceiver<Result<crate::ws::Message, crate::error::Error>>,
}
162
/// Error reported by the test websocket client.
///
/// It carries the underlying failure as a boxed, thread-safe error value.
#[derive(Debug)]
pub struct WsError {
    /// The failure that caused this error.
    cause: Box<dyn StdError + Send + Sync>,
}
168
169impl RequestBuilder {
170 pub fn method(mut self, method: &str) -> Self {
186 *self.req.method_mut() = method.parse().expect("valid method");
187 self
188 }
189
190 pub fn path(mut self, p: &str) -> Self {
206 let uri = p.parse().expect("test request path invalid");
207 *self.req.uri_mut() = uri;
208 self
209 }
210
211 pub fn header<K, V>(mut self, key: K, value: V) -> Self
225 where
226 HeaderName: TryFrom<K>,
227 HeaderValue: TryFrom<V>,
228 {
229 let name: HeaderName = TryFrom::try_from(key)
230 .map_err(|_| ())
231 .expect("invalid header name");
232 let value = TryFrom::try_from(value)
233 .map_err(|_| ())
234 .expect("invalid header value");
235 self.req.headers_mut().insert(name, value);
236 self
237 }
238
239 pub fn remote_addr(mut self, addr: SocketAddr) -> Self {
251 self.req.extensions_mut().insert(RemoteAddr(addr));
252 self
253 }
254
255 pub fn extension<T>(mut self, ext: T) -> Self
257 where
258 T: Clone + Send + Sync + 'static,
259 {
260 self.req.extensions_mut().insert(ext);
261 self
262 }
263
264 pub fn body(mut self, body: impl AsRef<[u8]>) -> Self {
275 let body = body.as_ref().to_vec();
276 let len = body.len();
277 *self.req.body_mut() = body.into();
278 self.header("content-length", len.to_string())
279 }
280
281 pub fn json(mut self, val: &impl Serialize) -> Self {
290 let vec = serde_json::to_vec(val).expect("json() must serialize to JSON");
291 let len = vec.len();
292 *self.req.body_mut() = vec.into();
293 self.header("content-length", len.to_string())
294 .header("content-type", "application/json")
295 }
296
297 pub async fn filter<F>(self, f: &F) -> Result<<F::Extract as OneOrTuple>::Output, F::Error>
323 where
324 F: Filter,
325 F::Future: Send + 'static,
326 F::Extract: OneOrTuple + Send + 'static,
327 F::Error: Send + 'static,
328 {
329 self.apply_filter(f).await.map(|ex| ex.one_or_tuple())
330 }
331
332 pub async fn matches<F>(self, f: &F) -> bool
357 where
358 F: Filter,
359 F::Future: Send + 'static,
360 F::Extract: Send + 'static,
361 F::Error: Send + 'static,
362 {
363 self.apply_filter(f).await.is_ok()
364 }
365
366 pub async fn reply<F>(self, f: &F) -> Response<Bytes>
370 where
371 F: Filter + 'static,
372 F::Extract: Reply + Send,
373 F::Error: IsReject + Send,
374 {
375 assert!(!route::is_set(), "nested test filter calls");
377
378 let route = Route::new(self.req);
379 let mut fut = Box::pin(
380 route::set(&route, move || f.filter(crate::filter::Internal)).then(|result| {
381 let res = match result {
382 Ok(rep) => rep.into_response(),
383 Err(rej) => {
384 tracing::debug!("rejected: {:?}", rej);
385 rej.into_response()
386 }
387 };
388 let (parts, body) = res.into_parts();
389 {
390 body.collect()
391 .map_ok(|chunk| Response::from_parts(parts, chunk.to_bytes()))
392 }
393 }),
394 );
395
396 let fut = future::poll_fn(move |cx| route::set(&route, || fut.as_mut().poll(cx)));
397
398 fut.await.expect("reply shouldn't fail")
399 }
400
401 fn apply_filter<F>(self, f: &F) -> impl Future<Output = Result<F::Extract, F::Error>>
402 where
403 F: Filter,
404 F::Future: Send + 'static,
405 F::Extract: Send + 'static,
406 F::Error: Send + 'static,
407 {
408 assert!(!route::is_set(), "nested test filter calls");
409
410 let route = Route::new(self.req);
411 let mut fut = Box::pin(route::set(&route, move || {
412 f.filter(crate::filter::Internal)
413 }));
414 future::poll_fn(move |cx| route::set(&route, || fut.as_mut().poll(cx)))
415 }
416}
417
#[cfg(feature = "websocket")]
impl WsBuilder {
    /// Sets the request path; delegates to [`RequestBuilder::path`].
    ///
    /// # Panics
    ///
    /// Panics if `p` does not parse as a URI.
    pub fn path(self, p: &str) -> Self {
        let req = self.req.path(p);
        WsBuilder { req }
    }

    /// Inserts a request header; delegates to [`RequestBuilder::header`].
    ///
    /// # Panics
    ///
    /// Panics if `key` or `value` is not a valid header name or value.
    pub fn header<K, V>(self, key: K, value: V) -> Self
    where
        HeaderName: TryFrom<K>,
        HeaderValue: TryFrom<V>,
    {
        let req = self.req.header(key, value);
        WsBuilder { req }
    }

    /// Performs a websocket handshake against a server running filter `f`.
    ///
    /// A spawned task binds a listener on `127.0.0.1` with port `0`, serves
    /// `f` on it, sends this request with the websocket upgrade headers added
    /// and, once upgraded, shuttles messages between the socket and the
    /// channels held by the returned [`WsClient`].
    ///
    /// # Errors
    ///
    /// Returns a [`WsError`] if connecting, the HTTP handshake, sending the
    /// request or the upgrade fails.
    ///
    /// # Panics
    ///
    /// Panics with "websocket handshake thread panicked" if the spawned task
    /// ends without reporting the upgrade outcome.
    pub async fn handshake<F>(self, f: F) -> Result<WsClient, WsError>
    where
        F: Filter + Clone + Send + Sync + 'static,
        F::Extract: Reply + Send,
        F::Error: IsReject + Send,
    {
        // Reports the upgrade outcome back to the caller.
        let (upgraded_tx, upgraded_rx) = oneshot::channel();
        // Client -> socket.
        let (wr_tx, wr_rx) = mpsc::unbounded();
        // Socket -> client.
        let (rd_tx, rd_rx) = mpsc::unbounded();

        tokio::spawn(async move {
            use tokio_tungstenite::tungstenite::protocol;

            let loopback = SocketAddr::from(([127, 0, 0, 1], 0));
            let listener = tokio::net::TcpListener::bind(loopback)
                .await
                .expect("binding");
            let addr = listener.local_addr().unwrap();
            tokio::spawn(async move {
                crate::serve(f).incoming(listener).run().await;
            });

            let mut req = self
                .req
                .header("connection", "upgrade")
                .header("upgrade", "websocket")
                .header("sec-websocket-version", "13")
                .header("sec-websocket-key", "dGhlIHNhbXBsZSBub25jZQ==")
                .req;

            // Rebuild the URI so that it points at the listener bound above,
            // keeping the original path and query.
            let query_string = req
                .uri()
                .query()
                .map(|q| format!("?{}", q))
                .unwrap_or_default();
            let target = format!("http://{}{}{}", addr, req.uri().path(), query_string)
                .parse()
                .expect("addr + path is valid URI");
            *req.uri_mut() = target;

            let upgrade = async move {
                let stream = tokio::net::TcpStream::connect(addr).await?;
                let stream = hyper_util::rt::TokioIo::new(stream);
                let (mut sender, conn) = hyper::client::conn::http1::handshake(stream).await?;
                tokio::spawn(async move {
                    let _ = conn.with_upgrades().await;
                });
                let response = sender.send_request(req).await?;
                hyper::upgrade::on(response)
                    .await
                    .map_err(|e| Box::new(e) as Box<dyn StdError + Send + Sync>)
            };

            let upgraded = match upgrade.await {
                Err(err) => {
                    let _ = upgraded_tx.send(Err(err));
                    return;
                }
                Ok(io) => {
                    let _ = upgraded_tx.send(Ok(()));
                    io
                }
            };

            let socket = crate::ws::WebSocket::from_raw_socket(
                upgraded,
                protocol::Role::Client,
                Default::default(),
            )
            .await;
            let (sink, stream) = socket.split();

            // Everything the client sends is forwarded to the socket.
            let write = wr_rx.map(Ok).forward(sink).map(|_| ());

            // Items read from the socket are forwarded to the client until
            // the first error or close message, which is not forwarded.
            let read = stream
                .take_while(|result| future::ready(matches!(result, Ok(m) if !m.is_close())))
                .for_each(move |item| {
                    rd_tx.unbounded_send(item).expect("ws receive error");
                    future::ready(())
                });

            future::join(write, read).await;
        });

        match upgraded_rx.await {
            Err(_canceled) => panic!("websocket handshake thread panicked"),
            Ok(Err(cause)) => Err(WsError::new(cause)),
            Ok(Ok(())) => Ok(WsClient {
                tx: wr_tx,
                rx: rd_rx,
            }),
        }
    }
}
584
#[cfg(feature = "websocket")]
impl WsClient {
    /// Sends a text message built from `text`.
    ///
    /// # Panics
    ///
    /// Panics if the outgoing channel refuses the message.
    pub async fn send_text(&mut self, text: impl Into<String>) {
        let message = crate::ws::Message::text(text.into());
        self.send(message).await;
    }

    /// Queues `msg` on the outgoing channel.
    ///
    /// # Panics
    ///
    /// Panics if the outgoing channel refuses the message.
    pub async fn send(&mut self, msg: crate::ws::Message) {
        let queued = self.tx.unbounded_send(msg);
        queued.unwrap();
    }

    /// Waits for the next item on the incoming channel.
    ///
    /// # Errors
    ///
    /// Returns a [`WsError`] wrapping the received error, or a "closed" error
    /// if the incoming channel has ended.
    pub async fn recv(&mut self) -> Result<crate::filters::ws::Message, WsError> {
        match self.rx.next().await {
            Some(Ok(message)) => Ok(message),
            Some(Err(err)) => Err(WsError::new(err)),
            None => Err(WsError::new("closed")),
        }
    }

    /// Waits for the incoming channel to end.
    ///
    /// # Errors
    ///
    /// Returns a [`WsError`] if an item arrives instead: either one that
    /// describes the unexpected message, or one wrapping the received error.
    pub async fn recv_closed(&mut self) -> Result<(), WsError> {
        match self.rx.next().await {
            None => Ok(()),
            Some(Ok(msg)) => Err(WsError::new(format!("received message: {:?}", msg))),
            Some(Err(err)) => Err(WsError::new(err)),
        }
    }

    /// Projects a pinned client to its pinned outgoing sender.
    fn pinned_tx(self: Pin<&mut Self>) -> Pin<&mut mpsc::UnboundedSender<crate::ws::Message>> {
        Pin::new(&mut self.get_mut().tx)
    }
}
629
#[cfg(feature = "websocket")]
impl fmt::Debug for WsClient {
    /// Writes only the type name; the channel fields are not shown.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // A struct formatter without fields emits exactly the name.
        f.write_str("WsClient")
    }
}
636
/// Sink half of the client: every operation is delegated to the outgoing
/// channel sender, with its errors wrapped in [`WsError`].
#[cfg(feature = "websocket")]
impl Sink<crate::ws::Message> for WsClient {
    type Error = WsError;

    /// Delegates to the outgoing sender's `poll_ready`.
    fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.pinned_tx().poll_ready(cx).map_err(WsError::new)
    }

    /// Delegates to the outgoing sender's `start_send`.
    fn start_send(self: Pin<&mut Self>, item: Message) -> Result<(), Self::Error> {
        self.pinned_tx().start_send(item).map_err(WsError::new)
    }

    /// Delegates to the outgoing sender's `poll_flush`.
    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.pinned_tx().poll_flush(cx).map_err(WsError::new)
    }

    /// Delegates to the outgoing sender's `poll_close`.
    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.pinned_tx().poll_close(cx).map_err(WsError::new)
    }
}
666
/// Stream half of the client: yields the items of the incoming channel, with
/// errors wrapped in [`WsError`].
#[cfg(feature = "websocket")]
impl Stream for WsClient {
    type Item = Result<crate::ws::Message, WsError>;

    /// Polls the incoming channel, converting the error type of a ready item.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let inbound = Pin::new(&mut self.get_mut().rx);
        inbound
            .poll_next(cx)
            .map(|next| next.map(|item| item.map_err(WsError::new)))
    }
}
681
#[cfg(feature = "websocket")]
impl WsError {
    /// Wraps anything convertible into a boxed error as a `WsError`.
    fn new<E>(cause: E) -> Self
    where
        E: Into<Box<dyn StdError + Send + Sync>>,
    {
        let cause = cause.into();
        WsError { cause }
    }
}
692
693impl fmt::Display for WsError {
694 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
695 write!(f, "websocket error: {}", self.cause)
696 }
697}
698
699impl StdError for WsError {
700 fn description(&self) -> &str {
701 "websocket error"
702 }
703}
704
mod inner {
    /// Maps a tuple to a friendlier shape: the unit and tuples of two or more
    /// elements are returned unchanged, while a one-element tuple is
    /// unwrapped to its only element.
    pub trait OneOrTuple {
        /// The type produced by [`OneOrTuple::one_or_tuple`].
        type Output;

        /// Converts `self` into [`OneOrTuple::Output`].
        fn one_or_tuple(self) -> Self::Output;
    }

    impl OneOrTuple for () {
        type Output = ();

        fn one_or_tuple(self) -> Self::Output {}
    }

    // Implements `OneOrTuple` for the tuple made of all listed type
    // parameters, then recurses on the list without its first entry until a
    // single parameter is left.
    macro_rules! impl_one_or_tuple {
        // Two or more elements: the tuple is returned as it is.
        ($head:ident, $($tail:ident),+) => {
            impl<$head, $($tail),+> OneOrTuple for ($head, $($tail),+) {
                type Output = Self;

                fn one_or_tuple(self) -> Self::Output {
                    self
                }
            }

            impl_one_or_tuple!($($tail),+);
        };
        // Exactly one element: unwrap it.
        ($only:ident) => {
            impl<$only> OneOrTuple for ($only,) {
                type Output = $only;

                fn one_or_tuple(self) -> Self::Output {
                    self.0
                }
            }
        };
    }

    impl_one_or_tuple!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15, T16);
}