compio_io/framed/
write.rs1use std::{
2 io,
3 task::{Poll, ready},
4};
5
6use compio_buf::BufResult;
7use futures_util::{FutureExt, Sink};
8
9use crate::{
10 AsyncWrite, AsyncWriteExt, PinBoxFuture,
11 framed::{Framed, codec::Encoder, frame::Framer},
12};
13
/// Write-side state machine driven by `Framed`'s `Sink` implementation.
///
/// The I/O object and the reusable frame buffer are owned by exactly one
/// variant at a time. They rest in `Idle` between operations and are moved
/// into the boxed future of `Writing`, `Closing` or `Flushing` while that
/// operation runs. `poll_sink` puts them back into `Idle` when the future
/// completes.
pub enum State<Io> {
    /// No operation in flight.
    ///
    /// The `Option` exists so the contents can be moved out (leaving `None`)
    /// at the moment a new operation is started.
    Idle(Option<(Io, Vec<u8>)>),
    /// A `write_all` of the encoded frame buffer is in progress; the buffer
    /// is returned inside the `BufResult`.
    Writing(PinBoxFuture<(Io, BufResult<(), Vec<u8>>)>),
    /// A `shutdown` of the I/O object is in progress.
    Closing(PinBoxFuture<(Io, io::Result<()>, Vec<u8>)>),
    /// A `flush` of the I/O object is in progress.
    Flushing(PinBoxFuture<(Io, io::Result<()>, Vec<u8>)>),
}
20
21impl<Io> State<Io> {
22 fn take_idle(&mut self) -> (Io, Vec<u8>) {
23 match self {
24 State::Idle(idle) => idle.take().expect("Inconsistent state"),
25 _ => unreachable!("`Framed` not in idle state"),
26 }
27 }
28
29 pub fn buf(&mut self) -> Option<&mut Vec<u8>> {
30 match self {
31 State::Idle(Some((_, buf))) => Some(buf),
32 _ => None,
33 }
34 }
35
36 pub fn start_flush(&mut self)
37 where
38 Io: AsyncWrite + 'static,
39 {
40 let (mut io, buf) = self.take_idle();
41 let fut = Box::pin(async move {
42 let res = io.flush().await;
43 (io, res, buf)
44 });
45 *self = State::Flushing(fut);
46 }
47
48 pub fn start_close(&mut self)
49 where
50 Io: AsyncWrite + 'static,
51 {
52 let (mut io, buf) = self.take_idle();
53 let fut = Box::pin(async move {
54 let res = io.shutdown().await;
55 (io, res, buf)
56 });
57 *self = State::Closing(fut);
58 }
59
60 pub fn start_write(&mut self)
61 where
62 Io: AsyncWrite + 'static,
63 {
64 let (mut io, buf) = self.take_idle();
65 let fut = Box::pin(async move {
66 let res = io.write_all(buf).await;
67 (io, res)
68 });
69 *self = State::Writing(fut);
70 }
71
72 pub fn poll_sink(&mut self, cx: &mut std::task::Context<'_>) -> Poll<io::Result<()>> {
74 let (io, res, buf) = match self {
75 State::Writing(fut) => {
76 let (io, BufResult(res, buf)) = ready!(fut.poll_unpin(cx));
77 (io, res, buf)
78 }
79 State::Closing(fut) | State::Flushing(fut) => ready!(fut.poll_unpin(cx)),
80 State::Idle(_) => {
81 return Poll::Ready(Ok(()));
82 }
83 };
84 *self = State::Idle(Some((io, buf)));
85 Poll::Ready(res)
86 }
87}
88
89impl<R, W, C, F, In, Out> Sink<In> for Framed<R, W, C, F, In, Out>
90where
91 W: AsyncWrite + 'static,
92 C: Encoder<In>,
93 F: Framer,
94 Self: Unpin,
95{
96 type Error = C::Error;
97
98 fn poll_ready(
99 self: std::pin::Pin<&mut Self>,
100 cx: &mut std::task::Context<'_>,
101 ) -> std::task::Poll<Result<(), Self::Error>> {
102 let this = self.get_mut();
103 match &mut this.write_state {
104 State::Idle(..) => Poll::Ready(Ok(())),
105 state => state.poll_sink(cx).map_err(C::Error::from),
106 }
107 }
108
109 fn start_send(self: std::pin::Pin<&mut Self>, item: In) -> Result<(), Self::Error> {
110 let this = self.get_mut();
111
112 let buf = this.write_state.buf().expect("`Framed` not in idle state");
113 buf.clear();
114 buf.reserve(64);
115 this.codec.encode(item, buf)?;
116 this.framer.enclose(buf);
117 this.write_state.start_write();
118
119 Ok(())
120 }
121
122 fn poll_flush(
123 self: std::pin::Pin<&mut Self>,
124 cx: &mut std::task::Context<'_>,
125 ) -> std::task::Poll<Result<(), Self::Error>> {
126 let this = self.get_mut();
127 match &mut this.write_state {
128 State::Idle(_) => {
129 this.write_state.start_flush();
130 this.write_state.poll_sink(cx).map_err(C::Error::from)
131 }
132 State::Writing(_) | State::Flushing(_) => {
133 this.write_state.poll_sink(cx).map_err(C::Error::from)
134 }
135 State::Closing(_) => unreachable!("`Framed` is closing, cannot flush"),
136 }
137 }
138
139 fn poll_close(
140 self: std::pin::Pin<&mut Self>,
141 cx: &mut std::task::Context<'_>,
142 ) -> std::task::Poll<Result<(), Self::Error>> {
143 let this = self.get_mut();
144 match &mut this.write_state {
145 state @ State::Idle(_) => {
146 state.start_close();
147 state.poll_sink(cx).map_err(C::Error::from)
148 }
149 _ => this.write_state.poll_sink(cx).map_err(C::Error::from),
150 }
151 }
152}