1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
use crate::encode::AsyncEncoder;

use std::pin::Pin;
use std::task::{Context, Poll};

use tokio::io::AsyncWrite;

/// Sending half of a bounded tokio mpsc channel; by default it carries byte buffers (`Vec<u8>`).
type Tx<T = Vec<u8>> = tokio::sync::mpsc::Sender<T>;

/// A writer that allows sending messages to the client
///
/// Alias for an [`AsyncEncoder`] layered over an [`AsyncMpscWriter`].
// TODO(review): the original note only says "this was renamed". Presumably this alias
// is kept for backward compatibility with the old name — confirm the new name and
// consider marking this alias `#[deprecated]`.
pub type AsyncWriter = AsyncEncoder<AsyncMpscWriter>;

/// A tokio mpsc based writer
///
/// Bytes written through [`AsyncWrite`] are accumulated in memory and sent over the
/// channel as a single `Vec<u8>` message when the writer is flushed.
pub struct AsyncMpscWriter {
    /// Bytes that have been written but not yet sent over `sender`.
    buffer: Vec<u8>,
    /// Sending half of the channel that flushed bytes are delivered through.
    sender: Tx,
}

/// Opaque `Debug` output: only the type name is printed, never the buffered bytes
/// or the channel sender.
impl std::fmt::Debug for AsyncMpscWriter {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("AsyncMpscWriter")
    }
}

/// Cloning shares the underlying channel but gives the clone its own, empty buffer:
/// bytes written to the original and not yet flushed are *not* copied.
impl Clone for AsyncMpscWriter {
    fn clone(&self) -> Self {
        let sender = self.sender.clone();
        Self {
            sender,
            buffer: Vec::default(),
        }
    }
}

impl AsyncMpscWriter {
    /// Create a new AsyncMpscWriter from this channel's sender.
    ///
    /// The writer starts with an empty buffer. Bytes written to it are held in memory
    /// and handed to `sender` when the writer is flushed.
    pub const fn new(sender: Tx) -> Self {
        let buffer = Vec::new();
        Self { sender, buffer }
    }
}

impl AsyncWrite for AsyncMpscWriter {
    fn poll_write(
        self: Pin<&mut Self>,
        _cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<std::io::Result<usize>> {
        self.get_mut().buffer.extend_from_slice(buf);
        Poll::Ready(Ok(buf.len()))
    }

    fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
        use std::io::{Error, ErrorKind};
        use tokio::sync::mpsc::error;

        let this = self.get_mut();
        let data = std::mem::take(&mut this.buffer);

        match this.sender.try_send(data) {
            Ok(..) => Poll::Ready(Ok(())),
            Err(error::TrySendError::Closed(..)) => {
                let err = Err(Error::new(ErrorKind::Other, "client disconnected"));
                Poll::Ready(err)
            }
            _ => unreachable!(),
        }
    }

    fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
        Poll::Ready(Ok(()))
    }
}