1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
use crate::RawTransportWrite;
use futures::ready;
use std::{
    fmt,
    future::Future,
    io,
    pin::Pin,
    task::{Context, Poll},
};
use tokio::{io::AsyncWrite, sync::mpsc};

/// Write portion of an in-memory channel.
///
/// Data written through the `AsyncWrite` implementation is copied into a `Vec<u8>` and
/// sent over the wrapped `mpsc::Sender`, one message per write.
pub struct InmemoryTransportWriteHalf {
    /// Sending side of the channel; set to `None` by `poll_shutdown`.
    tx: Option<mpsc::Sender<Vec<u8>>>,
    /// Send started by `poll_write` that has not completed yet; `None` when no write is
    /// in flight.
    task: Option<Pin<Box<dyn Future<Output = io::Result<usize>> + Send + Sync + 'static>>>,
}

impl InmemoryTransportWriteHalf {
    /// Creates a write half that sends its data through `tx`, starting with no write in
    /// flight.
    pub fn new(tx: mpsc::Sender<Vec<u8>>) -> Self {
        let tx = Some(tx);
        Self { tx, task: None }
    }
}

/// Manual `Debug` implementation: the boxed future held in `task` is a `dyn Future` without
/// a `Debug` bound, so it is rendered as a placeholder instead of being formatted directly.
impl fmt::Debug for InmemoryTransportWriteHalf {
    /// Writes the sender as-is and shows `task` as `"Some(...)"` or `"None"` depending on
    /// whether a send is currently stored.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Report the state of `task` itself; checking `tx` here would claim a pending
        // send whenever the sender exists, even if nothing is in flight.
        let task = if self.task.is_some() {
            "Some(...)"
        } else {
            "None"
        };

        f.debug_struct("InmemoryTransportWrite")
            .field("tx", &self.tx)
            .field("task", &task)
            .finish()
    }
}

// Empty impl: no items are defined here, so this type relies entirely on whatever
// `RawTransportWrite` already provides.
// NOTE(review): presumably a marker-style trait for raw write halves; confirm against the
// trait definition.
impl RawTransportWrite for InmemoryTransportWriteHalf {}

/// Asynchronous writer that forwards each write as a single `Vec<u8>` message over the
/// inner channel.
impl AsyncWrite for InmemoryTransportWriteHalf {
    /// Copies `buf` and sends it through the channel, resolving to `buf.len()` once the
    /// send completes successfully.
    ///
    /// Resolves to `Ok(0)` when the sender has already been dropped by `poll_shutdown` or
    /// when the send itself fails. If an earlier call returned `Poll::Pending`, later calls
    /// keep driving that same send and do not look at the `buf` they are given.
    fn poll_write(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<io::Result<usize>> {
        // Phase one: start a send, but only when none is already in flight.
        if self.task.is_none() {
            let sender = match self.tx.as_ref() {
                Some(tx) => tx.clone(),
                // No sender left, so nothing can be written.
                None => return Poll::Ready(Ok(0)),
            };
            let payload = buf.to_vec();
            let len = payload.len();
            let send = Box::pin(async move {
                // A failed send is reported as zero bytes written, not as an error.
                let sent: io::Result<usize> = match sender.send(payload).await {
                    Ok(()) => Ok(len),
                    Err(_) => Ok(0),
                };
                sent
            });
            self.task = Some(send);
        }

        // Phase two: drive the stored send, bailing out with `Pending` until it is done.
        let outcome = match self.task.as_mut() {
            Some(send) => ready!(send.as_mut().poll(cx)),
            None => unreachable!("a send is always stored at this point"),
        };

        // The send finished; clear the slot so the next call starts a fresh one.
        self.task = None;
        Poll::Ready(outcome)
    }

    /// Always ready: this type keeps no buffer of its own to flush.
    fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        Poll::Ready(Ok(()))
    }

    /// Drops the sender and any send still in flight, then reports success immediately.
    fn poll_shutdown(mut self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        self.tx = None;
        self.task = None;
        Poll::Ready(Ok(()))
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::{InmemoryTransport, IntoSplit};
    use tokio::io::AsyncWriteExt;

    /// A write that does not have to wait reports the full buffer length and delivers the
    /// bytes to the receiving end.
    #[tokio::test]
    async fn write_half_should_return_buf_len_if_can_send_immediately() {
        let (_tx, mut rx, transport) = InmemoryTransport::make(1);
        let (mut writer, _reader) = transport.into_split();

        // Nothing is queued yet, so the whole buffer should be accepted in one go
        let written = writer.write(&[1, 2, 3]).await.expect("Failed to write");
        assert_eq!(written, 3, "Unexpected byte count returned");

        // The bytes must have reached the receiving end of the channel
        let received = rx.try_recv().expect("Failed to recv data");
        assert_eq!(received, &[1, 2, 3]);
    }

    /// A write that cannot complete yet stays pending and succeeds once the receiver has
    /// consumed the earlier message.
    #[tokio::test]
    async fn write_half_should_return_support_eventually_sending_by_retrying_when_not_ready() {
        let (_tx, mut rx, transport) = InmemoryTransport::make(1);
        let (mut writer, _reader) = transport.into_split();

        // Occupy the channel so that the following write cannot complete right away
        let _ = writer.write(&[1, 2, 3]).await.expect("Failed to write");

        // The second write must stay pending while the first message is unread
        let second_write = writer.write(&[4, 5]);
        tokio::pin!(second_write);
        let first_poll = futures::poll!(&mut second_write);
        if !first_poll.is_pending() {
            panic!("Unexpected poll result: {:?}", first_poll);
        }

        // Reading the first message frees room for the second one
        let received = rx.try_recv().expect("Failed to recv data");
        assert_eq!(received, &[1, 2, 3]);

        // Polling again must now finish with the length of the second buffer
        match futures::poll!(second_write) {
            Poll::Ready(Ok(2)) => {}
            other => panic!("Unexpected poll result: {:?}", other),
        }

        // The second message must have arrived as well
        let received = rx.try_recv().expect("Failed to recv data");
        assert_eq!(received, &[4, 5]);
    }

    /// Writing after the receiver has been dropped reports zero bytes instead of an error.
    #[tokio::test]
    async fn write_half_should_zero_if_inner_channel_closed() {
        let (_tx, rx, transport) = InmemoryTransport::make(1);
        let (mut writer, _reader) = transport.into_split();

        // Close the channel by dropping the receiver the transport would send to
        drop(rx);

        // With the channel closed, the write should report that no bytes were sent
        let written = writer.write(&[1, 2, 3]).await.expect("Failed to write");
        assert_eq!(written, 0, "Unexpected byte count returned");
    }
}