1use std::{
2 io,
3 marker::PhantomData,
4 pin::Pin,
5 task::{Context, Poll},
6};
7
8use byteorder::{NetworkEndian, WriteBytesExt};
9use futures_core::ready;
10use futures_sink::Sink;
11use prost::Message;
12use tokio::io::AsyncWrite;
13
14use crate::{AsyncDestination, AsyncFrameDestination, Framed, SyncDestination};
15
/// Buffers encoded messages of type `T` and writes them to the wrapped writer `W`.
///
/// `D` is a marker type that selects which `ProstWriterFor` implementation is
/// used when a message is appended, and therefore how it is laid out in the
/// buffer.
#[derive(Debug)]
pub struct AsyncProstWriter<W, T, D> {
    /// The wrapped writer that receives the buffered bytes.
    writer: W,
    /// Number of bytes at the start of `buffer` already accepted by `writer`.
    pub(crate) written: usize,
    /// Encoded bytes waiting to be written to `writer`.
    pub(crate) buffer: Vec<u8>,
    /// Marker tying the writer to the message type `T`; holds no data.
    from: PhantomData<T>,
    /// Marker for the destination mode `D`; holds no data.
    dest: PhantomData<D>,
}
25
26impl<W, T, D> AsyncProstWriter<W, T, D> {
27 pub fn new(writer: W) -> Self {
29 Self {
30 writer,
31 written: 0,
32 buffer: Vec::new(),
33 from: PhantomData,
34 dest: PhantomData,
35 }
36 }
37
38 pub fn get_ref(&self) -> &W {
40 &self.writer
41 }
42
43 pub fn get_mut(&mut self) -> &mut W {
45 &mut self.writer
46 }
47
48 pub fn into_inner(self) -> W {
52 self.writer
53 }
54
55 pub(crate) fn make_for<D2>(self) -> AsyncProstWriter<W, T, D2> {
56 AsyncProstWriter {
57 buffer: self.buffer,
58 writer: self.writer,
59 written: self.written,
60 from: self.from,
61 dest: PhantomData,
62 }
63 }
64}
65
// The writer is declared `Unpin` for every `W`, `T` and `D`. The `Sink`
// implementation in this file only reaches the wrapped writer through
// `Pin::new`, and it separately requires `W: Unpin` to do so.
impl<W, T, D> Unpin for AsyncProstWriter<W, T, D> {}
67
68impl<W, T> Default for AsyncProstWriter<W, T, SyncDestination>
69where
70 W: Default,
71{
72 fn default() -> Self {
73 Self::from(W::default())
74 }
75}
76
77impl<W, T> From<W> for AsyncProstWriter<W, T, SyncDestination> {
78 fn from(writer: W) -> Self {
79 Self::new(writer)
80 }
81}
82
83impl<W, T> AsyncProstWriter<W, T, SyncDestination> {
84 pub fn for_async(self) -> AsyncProstWriter<W, T, AsyncDestination> {
86 self.make_for()
87 }
88
89 pub fn for_async_framed(self) -> AsyncProstWriter<W, T, AsyncFrameDestination> {
91 self.make_for()
92 }
93}
94
/// Internal hook that lets the `Sink` implementation stay generic over the
/// destination mode: each mode supplies its own way of encoding an item.
#[doc(hidden)]
pub trait ProstWriterFor<T> {
    /// Encodes `item` for later writing.
    ///
    /// The implementations in this file append the encoded bytes to the
    /// writer's internal buffer and report encoding failures as `io::Error`.
    fn append(&mut self, item: T) -> io::Result<()>;
}
99
100impl<W, F: Framed> ProstWriterFor<F> for AsyncProstWriter<W, F, AsyncFrameDestination> {
101 fn append(&mut self, item: F) -> Result<(), io::Error> {
102 let size = item.encoded_len();
103 self.buffer.write_u32::<NetworkEndian>(size)?;
104 item.encode(&mut self.buffer)?;
105 Ok(())
106 }
107}
108
109impl<W, T: Message> ProstWriterFor<T> for AsyncProstWriter<W, T, AsyncDestination> {
110 fn append(&mut self, item: T) -> Result<(), io::Error> {
111 let size = item.encoded_len() as u32;
112
113 self.buffer.write_u32::<NetworkEndian>(size)?;
114 item.encode(&mut self.buffer)?;
115 Ok(())
116 }
117}
118
119impl<W, T> ProstWriterFor<T> for AsyncProstWriter<W, T, SyncDestination>
121where
122 T: Message,
123{
124 fn append(&mut self, item: T) -> Result<(), io::Error> {
125 item.encode(&mut self.buffer)?;
126 Ok(())
127 }
128}
129
130impl<W, T, D> Sink<T> for AsyncProstWriter<W, T, D>
131where
132 W: AsyncWrite + Unpin,
133 Self: ProstWriterFor<T>,
134{
135 type Error = io::Error;
136
137 fn poll_ready(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
138 Poll::Ready(Ok(()))
139 }
140
141 fn start_send(mut self: Pin<&mut Self>, item: T) -> Result<(), Self::Error> {
142 if self.buffer.is_empty() {
143 }
151
152 self.append(item)
153 }
154
155 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
156 let this = self.get_mut();
157
158 while this.written != this.buffer.len() {
160 let n =
161 ready!(Pin::new(&mut this.writer).poll_write(cx, &this.buffer[this.written..]))?;
162 this.written += n;
163 }
164
165 this.buffer.clear();
167 this.written = 0;
168 Pin::new(&mut this.writer).poll_flush(cx)
169 }
170
171 fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
172 ready!(self.as_mut().poll_flush(cx))?;
173 Pin::new(&mut self.writer).poll_shutdown(cx)
174 }
175}