1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
use std::{
    io,
    marker::PhantomData,
    pin::Pin,
    task::{Context, Poll},
};

use byteorder::{NetworkEndian, WriteBytesExt};
use futures_core::ready;
use futures_sink::Sink;
use prost::Message;
use tokio::io::AsyncWrite;

use crate::{AsyncDestination, AsyncFrameDestination, Framed, SyncDestination};

/// A wrapper around an async writer that accepts, serializes, and sends prost-encoded values.
///
/// Values handed to the `Sink` implementation are first encoded into an in-memory buffer;
/// the buffered bytes are only written to the underlying writer by `poll_flush` /
/// `poll_close`. The `D` marker selects the wire format (see `for_async` and
/// `for_async_framed`).
#[derive(Debug)]
pub struct AsyncProstWriter<W, T, D> {
    /// The underlying writer that serialized bytes are sent to.
    writer: W,
    /// Number of bytes at the front of `buffer` already written to `writer` by an
    /// in-progress flush; reset to 0 once the whole buffer has been written.
    pub(crate) written: usize,
    /// Serialized values that have not yet been fully written to `writer`.
    pub(crate) buffer: Vec<u8>,
    /// Marker for the item type `T` this writer accepts.
    from: PhantomData<T>,
    /// Marker for the destination type `D` that selects the wire format.
    dest: PhantomData<D>,
}

impl<W, T, D> AsyncProstWriter<W, T, D> {
    /// Creates a writer around `writer` with an empty send buffer.
    pub fn new(writer: W) -> Self {
        AsyncProstWriter {
            buffer: Vec::new(),
            written: 0,
            writer,
            from: PhantomData,
            dest: PhantomData,
        }
    }

    /// Borrows the wrapped writer.
    pub fn get_ref(&self) -> &W {
        let AsyncProstWriter { writer, .. } = self;
        writer
    }

    /// Mutably borrows the wrapped writer.
    pub fn get_mut(&mut self) -> &mut W {
        let AsyncProstWriter { writer, .. } = self;
        writer
    }

    /// Consumes this `AsyncProstWriter` and hands back the wrapped writer.
    ///
    /// Any bytes still sitting in the internal buffer (serialized but not yet sent) are
    /// discarded.
    pub fn into_inner(self) -> W {
        let AsyncProstWriter { writer, .. } = self;
        writer
    }

    /// Re-tags this writer with a different destination marker `D2`, keeping the wrapped
    /// writer, the pending buffer, and the flush progress as they are.
    pub(crate) fn make_for<D2>(self) -> AsyncProstWriter<W, T, D2> {
        let AsyncProstWriter {
            writer,
            written,
            buffer,
            from,
            ..
        } = self;
        AsyncProstWriter {
            writer,
            written,
            buffer,
            from,
            dest: PhantomData,
        }
    }
}

// `AsyncProstWriter` does not rely on structural pinning of any field: the `Sink` impl only
// reaches `writer` through `Pin::new(&mut this.writer)`, and that impl requires `W: Unpin`.
// Implementing `Unpin` for every `W`, `T`, `D` (the `PhantomData` markers included) lets
// `Pin<&mut Self>` be dereferenced mutably and unwrapped with `Pin::get_mut`.
impl<W, T, D> Unpin for AsyncProstWriter<W, T, D> {}

impl<W, T> Default for AsyncProstWriter<W, T, SyncDestination>
where
    W: Default,
{
    /// Builds a `SyncDestination` writer around a default-constructed `W`.
    fn default() -> Self {
        let inner = W::default();
        Self::new(inner)
    }
}

impl<W, T> From<W> for AsyncProstWriter<W, T, SyncDestination> {
    fn from(writer: W) -> Self {
        Self::new(writer)
    }
}

impl<W, T> AsyncProstWriter<W, T, SyncDestination> {
    /// make this writer include the serialized data's size before each serialized value.
    pub fn for_async(self) -> AsyncProstWriter<W, T, AsyncDestination> {
        self.make_for()
    }

    /// make this writer include the serialized data's header and body size before serialized value
    pub fn for_async_framed(self) -> AsyncProstWriter<W, T, AsyncFrameDestination> {
        self.make_for()
    }
}

/// Serializes one item into the writer's pending buffer using the wire format selected by
/// the destination marker type.
///
/// The implementations in this module only append to the internal buffer; the bytes are
/// written to the underlying writer later, by the `Sink` flush/close methods.
#[doc(hidden)]
pub trait ProstWriterFor<T> {
    /// Encodes `item` and appends it to the pending buffer.
    fn append(&mut self, item: T) -> Result<(), io::Error>;
}

impl<W, F: Framed> ProstWriterFor<F> for AsyncProstWriter<W, F, AsyncFrameDestination> {
    /// Appends one frame: a 4-byte big-endian size word (as reported by
    /// `Framed::encoded_len`) followed by the frame written by `Framed::encode`.
    ///
    /// # Errors
    ///
    /// Propagates the (converted) error from `Framed::encode`. In that case the pending
    /// buffer is truncated back to its length before this call, so no dangling size word
    /// or partial frame is sent on the next flush.
    fn append(&mut self, item: F) -> Result<(), io::Error> {
        let frame_start = self.buffer.len();
        let size = item.encoded_len();
        self.buffer.write_u32::<NetworkEndian>(size)?;
        if let Err(err) = item.encode(&mut self.buffer) {
            // Undo the size word (and anything `encode` may have written before failing) so
            // the stream stays in sync for subsequent items.
            self.buffer.truncate(frame_start);
            return Err(err.into());
        }
        Ok(())
    }
}

impl<W, T: Message> ProstWriterFor<T> for AsyncProstWriter<W, T, AsyncDestination> {
    /// Appends `item` as a 4-byte big-endian (network order) length prefix followed by the
    /// prost-encoded message.
    ///
    /// # Errors
    ///
    /// * `io::ErrorKind::InvalidInput` if the encoded message is longer than `u32::MAX`
    ///   bytes and therefore cannot be described by the 4-byte prefix.
    /// * The converted prost encode error if encoding fails.
    ///
    /// On error the pending buffer is left exactly as it was before the call.
    fn append(&mut self, item: T) -> Result<(), io::Error> {
        // A plain `as u32` would silently truncate the length of an oversized message and
        // desynchronize the reader; reject such messages instead.
        let size: u32 = std::convert::TryFrom::try_from(item.encoded_len()).map_err(|_| {
            io::Error::new(
                io::ErrorKind::InvalidInput,
                "encoded message does not fit in a u32 length prefix",
            )
        })?;

        let frame_start = self.buffer.len();
        self.buffer.write_u32::<NetworkEndian>(size)?;
        if let Err(err) = item.encode(&mut self.buffer) {
            // Drop the length prefix just written so a later flush does not send a header
            // without its body.
            self.buffer.truncate(frame_start);
            return Err(err.into());
        }
        Ok(())
    }
}

// NOTE(review): a FIXME here asked why this impl exists. For `SyncDestination`, values are
// encoded back to back with no size prefix, unlike the `AsyncDestination` and
// `AsyncFrameDestination` impls; presumably this targets receivers that delimit messages
// on their own — TODO confirm and document the intended consumer.
impl<W, T> ProstWriterFor<T> for AsyncProstWriter<W, T, SyncDestination>
where
    T: Message,
{
    /// Encodes `item` straight into the pending buffer, without any size prefix.
    fn append(&mut self, item: T) -> Result<(), io::Error> {
        item.encode(&mut self.buffer).map_err(io::Error::from)
    }
}

impl<W, T, D> Sink<T> for AsyncProstWriter<W, T, D>
where
    W: AsyncWrite + Unpin,
    Self: ProstWriterFor<T>,
{
    type Error = io::Error;

    /// Always ready: items are only serialized into the in-memory buffer by `start_send`,
    /// and the buffer keeps growing until the caller flushes or closes the sink.
    fn poll_ready(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        Poll::Ready(Ok(()))
    }

    /// Serializes `item` into the pending buffer; nothing is written to the underlying
    /// writer until `poll_flush` / `poll_close`.
    fn start_send(mut self: Pin<&mut Self>, item: T) -> Result<(), Self::Error> {
        // NOTE: in theory, when the buffer is empty, we could short-circuit and have prost
        // write directly into self.writer. This would be much more efficient in the common
        // case, as we would avoid the extra buffering. The idea would be to serialize first,
        // and *if* it errors, see how many bytes were written, serialize again into a Vec,
        // and then keep only the bytes following the number that were written in our
        // buffer. Unfortunately, prost will not tell us that number at the moment, and
        // instead just fails.
        self.append(item)
    }

    /// Writes out all buffered bytes, then flushes the underlying writer.
    ///
    /// Progress is tracked in `written`, so a `Pending` result resumes where it left off.
    ///
    /// # Errors
    ///
    /// Returns `io::ErrorKind::WriteZero` if the underlying writer accepts zero bytes while
    /// data is still pending (otherwise this loop would spin forever), and propagates any
    /// error from `poll_write` / `poll_flush`.
    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        let this = self.get_mut();

        // write stuff out if we need to
        while this.written < this.buffer.len() {
            let n =
                ready!(Pin::new(&mut this.writer).poll_write(cx, &this.buffer[this.written..]))?;
            if n == 0 {
                return Poll::Ready(Err(io::Error::new(
                    io::ErrorKind::WriteZero,
                    "failed to write buffered data to the underlying writer",
                )));
            }
            this.written += n;
        }

        // Everything has been handed to the writer; reuse the allocation for the next batch.
        this.buffer.clear();
        this.written = 0;
        // we have to flush before we're really done
        Pin::new(&mut this.writer).poll_flush(cx)
    }

    /// Flushes all pending data, then shuts down the underlying writer.
    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        ready!(self.as_mut().poll_flush(cx))?;
        Pin::new(&mut self.writer).poll_shutdown(cx)
    }
}