fake_socket/
lib.rs

1use pin_project::pin_project;
2use std::{
3    marker::PhantomData,
4    pin::Pin,
5    task::{Context, Poll},
6};
7
8use futures::{Sink, Stream};
9use tokio::sync::mpsc;
10
/// A [`Stream`] adapter over a tokio unbounded receiver.
///
/// Every value received from `inner` is yielded wrapped in `Ok`. The error type
/// `E` is never constructed; it is held only as a `PhantomData` marker so the
/// stream's item type can be `Result<T, E>`.
#[pin_project]
pub struct ReceiverStream<T, E> {
    /// Channel the stream's items are read from.
    #[pin]
    inner: mpsc::UnboundedReceiver<T>,
    /// Type-level marker fixing the error half of the item type.
    _error: PhantomData<E>,
}
17
/// A [`Sink`] adapter over a tokio unbounded sender.
///
/// `E` is only a type-level marker (held in `PhantomData`) that fixes the sink's
/// `Error` type; the sink's operations never return an error.
pub struct SenderSink<T, E> {
    /// Channel that items passed to the sink are sent on.
    inner: mpsc::UnboundedSender<T>,
    /// Type-level marker fixing the sink's error type.
    _error: PhantomData<E>,
}
22
/// A fake socket to hand to the code under test.
///
/// It is both a [`Stream`] of incoming `Result<T, E>` items and a [`Sink`] for
/// outgoing `T` items, each backed by an unbounded channel. When built by
/// [`create_fake_connection`], the other ends of both channels are held by the
/// paired [`FakeClient`].
#[pin_project]
pub struct FakeSocket<T, E> {
    /// Outgoing half: items written to the socket's `Sink` go here.
    #[pin]
    sender: SenderSink<T, E>,
    /// Incoming half: the socket's `Stream` reads from here.
    #[pin]
    receiver: ReceiverStream<T, E>,
}
30
/// The test-driver end of a fake connection.
///
/// When built by [`create_fake_connection`], messages sent through this client
/// are read by the paired [`FakeSocket`]'s stream, and items the socket writes to
/// its sink are received here.
pub struct FakeClient<T> {
    /// Channel toward the socket's incoming stream.
    sender: mpsc::UnboundedSender<T>,
    /// Channel fed by the socket's outgoing sink.
    receiver: mpsc::UnboundedReceiver<T>,
}
35
36impl<T, E> ReceiverStream<T, E> {
37    pub fn new(inner: mpsc::UnboundedReceiver<T>) -> Self {
38        Self {
39            inner,
40            _error: PhantomData::default(),
41        }
42    }
43}
44
45impl<T, E> SenderSink<T, E> {
46    pub fn new(inner: mpsc::UnboundedSender<T>) -> Self {
47        Self {
48            inner,
49            _error: PhantomData::default(),
50        }
51    }
52}
53
54impl<T, E> FakeSocket<T, E> {
55    pub fn new(rx: mpsc::UnboundedReceiver<T>, tx: mpsc::UnboundedSender<T>) -> Self {
56        Self {
57            sender: SenderSink::new(tx),
58            receiver: ReceiverStream::new(rx),
59        }
60    }
61}
62
63impl<T, E> Stream for ReceiverStream<T, E> {
64    type Item = Result<T, E>;
65
66    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
67        let data = futures::ready!(self.inner.poll_recv(cx));
68        Poll::Ready(Ok(data).transpose())
69    }
70}
71
72impl<T, E> Stream for FakeSocket<T, E> {
73    type Item = Result<T, E>;
74
75    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
76        self.project().receiver.poll_next(cx)
77    }
78}
79
80impl<T, E> Sink<T> for SenderSink<T, E> {
81    type Error = E;
82
83    fn poll_ready(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
84        Poll::Ready(Ok(()))
85    }
86
87    fn start_send(self: Pin<&mut Self>, item: T) -> Result<(), Self::Error> {
88        let _ = self.inner.send(item);
89        Ok(())
90    }
91
92    fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
93        Poll::Ready(Ok(()))
94    }
95
96    fn poll_close(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
97        Poll::Ready(Ok(()))
98    }
99}
100
101impl<T, E> Sink<T> for FakeSocket<T, E> {
102    type Error = E;
103
104    fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
105        self.project().sender.poll_ready(cx)
106    }
107
108    fn start_send(self: Pin<&mut Self>, item: T) -> Result<(), Self::Error> {
109        self.project().sender.start_send(item)
110    }
111
112    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
113        self.project().sender.poll_flush(cx)
114    }
115
116    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
117        self.project().sender.poll_close(cx)
118    }
119}
120
121impl<T> FakeClient<T> {
122    pub fn new(sender: mpsc::UnboundedSender<T>, receiver: mpsc::UnboundedReceiver<T>) -> Self {
123        Self { sender, receiver }
124    }
125
126    pub fn send(&self, msg: T) -> Result<(), mpsc::error::SendError<T>> {
127        self.sender.send(msg)
128    }
129
130    pub async fn recv(&mut self) -> Option<T> {
131        self.receiver.recv().await
132    }
133}
134
135/// Create fake client and fake socket. The socket could be sent to the function to be tested.
136/// For example:
137/// ```
138/// let (mut client, socket) = create_fake_connect();
139/// tokio::spawn(async move {
140///     handle_socket(socket, state).await;
141/// });
142///
143/// let msg = ...;
144/// client.send(msg).await;
145/// if let Some(msg1) = client.recv().await {
146///    assert_eq!(msg1, ...);
147/// }
148/// ```
149pub fn create_fake_connection<T, E>() -> (FakeClient<T>, FakeSocket<T, E>) {
150    let (tx1, rx1) = mpsc::unbounded_channel();
151    let (tx2, rx2) = mpsc::unbounded_channel();
152    let socket = FakeSocket::<T, E>::new(rx1, tx2);
153    let client = FakeClient::new(tx1, rx2);
154    (client, socket)
155}