// async_prost/writer.rs

1use std::{
2    io,
3    marker::PhantomData,
4    pin::Pin,
5    task::{Context, Poll},
6};
7
8use byteorder::{NetworkEndian, WriteBytesExt};
9use futures_core::ready;
10use futures_sink::Sink;
11use prost::Message;
12use tokio::io::AsyncWrite;
13
14use crate::{AsyncDestination, AsyncFrameDestination, Framed, SyncDestination};
15
/// A wrapper around an async writer that accepts, serializes, and sends prost-encoded values.
///
/// Values handed to the `Sink` implementation are encoded into an internal buffer and are only
/// written to the wrapped writer when the sink is flushed or closed. `T` is the value type and
/// `D` is a marker type selecting how each value is framed in the output (`SyncDestination`,
/// `AsyncDestination`, or `AsyncFrameDestination`).
#[derive(Debug)]
pub struct AsyncProstWriter<W, T, D> {
    /// The underlying writer that buffered bytes are sent to.
    writer: W,
    /// Number of bytes at the front of `buffer` that have already been written to `writer`.
    pub(crate) written: usize,
    /// Serialized bytes waiting to be written to `writer`.
    pub(crate) buffer: Vec<u8>,
    from: PhantomData<T>,
    dest: PhantomData<D>,
}
25
26impl<W, T, D> AsyncProstWriter<W, T, D> {
27    /// create a new async prost writer
28    pub fn new(writer: W) -> Self {
29        Self {
30            writer,
31            written: 0,
32            buffer: Vec::new(),
33            from: PhantomData,
34            dest: PhantomData,
35        }
36    }
37
38    /// Gets a reference to the underlying writer.
39    pub fn get_ref(&self) -> &W {
40        &self.writer
41    }
42
43    /// Gets a mutable reference to the underlying writer.
44    pub fn get_mut(&mut self) -> &mut W {
45        &mut self.writer
46    }
47
48    /// Unwraps this `AsyncProstWriter`, returning the underlying writer.
49    ///
50    /// Note that any leftover serialized data that has not yet been sent is lost.
51    pub fn into_inner(self) -> W {
52        self.writer
53    }
54
55    pub(crate) fn make_for<D2>(self) -> AsyncProstWriter<W, T, D2> {
56        AsyncProstWriter {
57            buffer: self.buffer,
58            writer: self.writer,
59            written: self.written,
60            from: self.from,
61            dest: PhantomData,
62        }
63    }
64}
65
66impl<W, T, D> Unpin for AsyncProstWriter<W, T, D> {}
67
68impl<W, T> Default for AsyncProstWriter<W, T, SyncDestination>
69where
70    W: Default,
71{
72    fn default() -> Self {
73        Self::from(W::default())
74    }
75}
76
77impl<W, T> From<W> for AsyncProstWriter<W, T, SyncDestination> {
78    fn from(writer: W) -> Self {
79        Self::new(writer)
80    }
81}
82
83impl<W, T> AsyncProstWriter<W, T, SyncDestination> {
84    /// make this writer include the serialized data's size before each serialized value.
85    pub fn for_async(self) -> AsyncProstWriter<W, T, AsyncDestination> {
86        self.make_for()
87    }
88
89    /// make this writer include the serialized data's header and body size before serialized value
90    pub fn for_async_framed(self) -> AsyncProstWriter<W, T, AsyncFrameDestination> {
91        self.make_for()
92    }
93}
94
#[doc(hidden)]
pub trait ProstWriterFor<T> {
    /// Serialize `item` into the writer's internal buffer, adding whatever framing the
    /// destination mode calls for. The implementations in this module only touch the buffer;
    /// the bytes reach the underlying writer when the sink is flushed.
    fn append(&mut self, item: T) -> io::Result<()>;
}
99
100impl<W, F: Framed> ProstWriterFor<F> for AsyncProstWriter<W, F, AsyncFrameDestination> {
101    fn append(&mut self, item: F) -> Result<(), io::Error> {
102        let size = item.encoded_len();
103        self.buffer.write_u32::<NetworkEndian>(size)?;
104        item.encode(&mut self.buffer)?;
105        Ok(())
106    }
107}
108
109impl<W, T: Message> ProstWriterFor<T> for AsyncProstWriter<W, T, AsyncDestination> {
110    fn append(&mut self, item: T) -> Result<(), io::Error> {
111        let size = item.encoded_len() as u32;
112
113        self.buffer.write_u32::<NetworkEndian>(size)?;
114        item.encode(&mut self.buffer)?;
115        Ok(())
116    }
117}
118
119// FIXME: why do we need this impl without writing the size?
120impl<W, T> ProstWriterFor<T> for AsyncProstWriter<W, T, SyncDestination>
121where
122    T: Message,
123{
124    fn append(&mut self, item: T) -> Result<(), io::Error> {
125        item.encode(&mut self.buffer)?;
126        Ok(())
127    }
128}
129
130impl<W, T, D> Sink<T> for AsyncProstWriter<W, T, D>
131where
132    W: AsyncWrite + Unpin,
133    Self: ProstWriterFor<T>,
134{
135    type Error = io::Error;
136
137    fn poll_ready(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
138        Poll::Ready(Ok(()))
139    }
140
141    fn start_send(mut self: Pin<&mut Self>, item: T) -> Result<(), Self::Error> {
142        if self.buffer.is_empty() {
143            // NOTE: in theory we could have a short-circuit here that tries to have prost write
144            // directly into self.writer. this would be way more efficient in the common case as we
145            // don't have to do the extra buffering. the idea would be to serialize fist, and *if*
146            // it errors, see how many bytes were written, serialize again into a Vec, and then
147            // keep only the bytes following the number that were written in our buffer.
148            // unfortunately, prost will not tell us that number at the moment, and instead just
149            // fail.
150        }
151
152        self.append(item)
153    }
154
155    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
156        let this = self.get_mut();
157
158        // write stuff out if we need to
159        while this.written != this.buffer.len() {
160            let n =
161                ready!(Pin::new(&mut this.writer).poll_write(cx, &this.buffer[this.written..]))?;
162            this.written += n;
163        }
164
165        // we have to flush before we're really done
166        this.buffer.clear();
167        this.written = 0;
168        Pin::new(&mut this.writer).poll_flush(cx)
169    }
170
171    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
172        ready!(self.as_mut().poll_flush(cx))?;
173        Pin::new(&mut self.writer).poll_shutdown(cx)
174    }
175}