1use std::{
2 io,
3 task::{Poll, ready},
4};
5
6use compio_buf::{BufResult, IoBufMut};
7use futures_util::{FutureExt, Sink};
8
9use crate::{
10 AsyncWrite, AsyncWriteExt, PinBoxFuture,
11 framed::{Framed, INCONSISTENT_ERROR, codec::Encoder, frame::Framer, panic_config_polled},
12};
13
/// Write-side state machine of a `Framed` sink.
///
/// The I/O object and the write buffer are owned by whichever variant is
/// current. The in-flight variants (`Writing`, `Closing`, `Flushing`) move them
/// into a boxed future and get them back when that future resolves.
pub enum State<Io, B> {
    /// Not yet driven: the I/O object and buffer are still being supplied via
    /// `with_io` / `with_buf`, so either slot may still be `None`.
    Configuring(Option<Io>, Option<B>),
    /// Ready for the next operation. Holds `None` only between `take_idle`
    /// removing the pair and the caller installing the next in-flight state.
    Idle(Option<(Io, B)>),
    /// A `write_all` of the encoded frame is in progress.
    Writing(PinBoxFuture<(Io, BufResult<(), B>)>),
    /// A `shutdown` of the I/O object is in progress.
    Closing(PinBoxFuture<(Io, io::Result<()>, B)>),
    /// A `flush` of the I/O object is in progress.
    Flushing(PinBoxFuture<(Io, io::Result<()>, B)>),
}
21
/// Finishes configuration: moves the I/O object and the buffer out of their
/// `Option` slots and stores them in `$this` as `State::Idle`.
///
/// `$this` must dereference to a `State` (e.g. `&mut State`); `$io` and
/// `$buf` must be mutable `Option` places or `&mut Option<_>`.
///
/// Both slots are emptied *before* `*$this` is assigned, so any borrows the
/// caller holds on them are finished by the time the state is overwritten.
///
/// Panics with "io is empty" / "buf is empty" if the respective slot is `None`.
macro_rules! initialize {
    ($this:expr, $io:expr, $buf:expr) => {{
        let taken_io = $io.take().expect("io is empty");
        let taken_buf = $buf.take().expect("buf is empty");
        *$this = State::Idle(Some((taken_io, taken_buf)));
    }};
}
30
31impl<Io> State<Io, Vec<u8>> {
32 pub fn empty() -> Self {
33 State::Configuring(None, Some(Vec::new()))
34 }
35}
36
37impl<Io, B> State<Io, B> {
38 pub fn with_io<I>(self, io: I) -> State<I, B> {
39 let State::Configuring(_, b) = self else {
40 panic_config_polled()
41 };
42 State::Configuring(Some(io), b)
43 }
44
45 pub fn with_buf<Buf: IoBufMut>(self, buf: Buf) -> State<Io, Buf> {
46 let State::Configuring(io, _) = self else {
47 panic_config_polled()
48 };
49 State::Configuring(io, Some(buf))
50 }
51
52 fn take_idle(&mut self) -> (Io, B) {
53 match self {
54 State::Configuring(io, buf) => {
55 initialize!(self, io, buf);
56 self.take_idle()
57 }
58 State::Idle(idle) => idle.take().expect(INCONSISTENT_ERROR),
59 _ => unreachable!("{}", INCONSISTENT_ERROR),
60 }
61 }
62
63 fn buf(&mut self) -> Option<&mut B> {
64 match self {
65 State::Idle(Some((_, buf))) => Some(buf),
66 State::Configuring(io, buf) => {
67 initialize!(self, io, buf);
68 self.buf()
69 }
70 _ => None,
71 }
72 }
73
74 fn poll_sink(&mut self, cx: &mut std::task::Context<'_>) -> Poll<io::Result<()>> {
75 let (io, res, buf) = match self {
76 State::Configuring(io, buf) => {
77 initialize!(self, io, buf);
78 return Poll::Ready(Ok(()));
79 }
80 State::Idle(_) => {
81 return Poll::Ready(Ok(()));
82 }
83 State::Writing(fut) => {
84 let (io, BufResult(res, buf)) = ready!(fut.poll_unpin(cx));
85 (io, res, buf)
86 }
87 State::Closing(fut) | State::Flushing(fut) => ready!(fut.poll_unpin(cx)),
88 };
89 *self = State::Idle(Some((io, buf)));
90 Poll::Ready(res)
91 }
92}
93
94impl<Io: AsyncWrite + 'static, B: IoBufMut> State<Io, B> {
95 fn start_flush(&mut self) {
96 let (mut io, buf) = self.take_idle();
97 let fut = Box::pin(async move {
98 let res = io.flush().await;
99 (io, res, buf)
100 });
101 *self = State::Flushing(fut);
102 }
103
104 fn start_close(&mut self) {
105 let (mut io, buf) = self.take_idle();
106 let fut = Box::pin(async move {
107 let res = io.shutdown().await;
108 (io, res, buf)
109 });
110 *self = State::Closing(fut);
111 }
112
113 fn start_write(&mut self) {
114 let (mut io, buf) = self.take_idle();
115 let fut = Box::pin(async move {
116 let res = io.write_all(buf).await;
117 (io, res)
118 });
119 *self = State::Writing(fut);
120 }
121}
122
123impl<R, W, C, F, In, Out, B> Sink<In> for Framed<R, W, C, F, In, Out, B>
124where
125 W: AsyncWrite + 'static,
126 C: Encoder<In, B>,
127 F: Framer<B>,
128 B: IoBufMut,
129 Self: Unpin,
130{
131 type Error = C::Error;
132
133 fn poll_ready(
134 self: std::pin::Pin<&mut Self>,
135 cx: &mut std::task::Context<'_>,
136 ) -> std::task::Poll<Result<(), Self::Error>> {
137 let this = self.get_mut();
138 match &mut this.write_state {
139 State::Idle(..) => Poll::Ready(Ok(())),
140 state => state.poll_sink(cx).map_err(C::Error::from),
141 }
142 }
143
144 fn start_send(self: std::pin::Pin<&mut Self>, item: In) -> Result<(), Self::Error> {
145 let this = self.get_mut();
146
147 let buf = this.write_state.buf().expect("`Framed` not in idle state");
148 buf.clear();
149 let _ = buf.reserve(64);
150 if let Err(e) = this.codec.encode(item, buf) {
151 buf.clear();
152 return Err(e);
153 };
154 this.framer.enclose(buf);
155 this.write_state.start_write();
156
157 Ok(())
158 }
159
160 fn poll_flush(
161 mut self: std::pin::Pin<&mut Self>,
162 cx: &mut std::task::Context<'_>,
163 ) -> std::task::Poll<Result<(), Self::Error>> {
164 let this = self.as_mut().get_mut();
165 match &mut this.write_state {
166 State::Configuring(io, buf) => {
167 initialize!(&mut this.write_state, io, buf);
168 self.poll_flush(cx)
169 }
170 State::Idle(_) => {
171 this.write_state.start_flush();
172 this.write_state.poll_sink(cx).map_err(C::Error::from)
173 }
174 State::Writing(_) | State::Flushing(_) => {
175 this.write_state.poll_sink(cx).map_err(C::Error::from)
176 }
177 State::Closing(_) => unreachable!("`Framed` is closing, cannot flush"),
178 }
179 }
180
181 fn poll_close(
182 self: std::pin::Pin<&mut Self>,
183 cx: &mut std::task::Context<'_>,
184 ) -> std::task::Poll<Result<(), Self::Error>> {
185 let this = self.get_mut();
186 match &mut this.write_state {
187 state @ State::Idle(_) => {
188 state.start_close();
189 state.poll_sink(cx).map_err(C::Error::from)
190 }
191 _ => this.write_state.poll_sink(cx).map_err(C::Error::from),
192 }
193 }
194}