Skip to main content

compio_io/framed/
write.rs

1use std::{
2    io,
3    task::{Poll, ready},
4};
5
6use compio_buf::{BufResult, IoBufMut};
7use futures_util::{FutureExt, Sink};
8
9use crate::{
10    AsyncWrite, AsyncWriteExt, PinBoxFuture,
11    framed::{Framed, INCONSISTENT_ERROR, codec::Encoder, frame::Framer, panic_config_polled},
12};
13
14pub enum State<Io, B> {
15    Configuring(Option<Io>, Option<B>),
16    Idle(Option<(Io, B)>),
17    Writing(PinBoxFuture<(Io, BufResult<(), B>)>),
18    Closing(PinBoxFuture<(Io, io::Result<()>, B)>),
19    Flushing(PinBoxFuture<(Io, io::Result<()>, B)>),
20}
21
/// Promotes a `Configuring` state to `Idle`.
///
/// `$this` is a `&mut State<_, _>`; `$io` and `$buf` are the `Option` slots
/// taken from its `Configuring` variant. Both slots are emptied with `take()`
/// and their contents become `State::Idle(Some((io, buf)))`.
///
/// Panics with "io is empty" (checked first) or "buf is empty" if a slot is
/// `None`.
macro_rules! initialize {
    ($this:expr, $io:expr, $buf:expr) => {{
        // Move both values out before overwriting `*$this`, so the borrows of
        // the slots have ended by the time the state is replaced.
        let taken_io = $io.take().expect("io is empty");
        let taken_buf = $buf.take().expect("buf is empty");
        *$this = State::Idle(Some((taken_io, taken_buf)));
    }};
}
30
31impl<Io> State<Io, Vec<u8>> {
32    pub fn empty() -> Self {
33        State::Configuring(None, Some(Vec::new()))
34    }
35}
36
37impl<Io, B> State<Io, B> {
38    pub fn with_io<I>(self, io: I) -> State<I, B> {
39        let State::Configuring(_, b) = self else {
40            panic_config_polled()
41        };
42        State::Configuring(Some(io), b)
43    }
44
45    pub fn with_buf<Buf: IoBufMut>(self, buf: Buf) -> State<Io, Buf> {
46        let State::Configuring(io, _) = self else {
47            panic_config_polled()
48        };
49        State::Configuring(io, Some(buf))
50    }
51
52    fn take_idle(&mut self) -> (Io, B) {
53        match self {
54            State::Configuring(io, buf) => {
55                initialize!(self, io, buf);
56                self.take_idle()
57            }
58            State::Idle(idle) => idle.take().expect(INCONSISTENT_ERROR),
59            _ => unreachable!("{}", INCONSISTENT_ERROR),
60        }
61    }
62
63    fn buf(&mut self) -> Option<&mut B> {
64        match self {
65            State::Idle(Some((_, buf))) => Some(buf),
66            State::Configuring(io, buf) => {
67                initialize!(self, io, buf);
68                self.buf()
69            }
70            _ => None,
71        }
72    }
73
74    fn poll_sink(&mut self, cx: &mut std::task::Context<'_>) -> Poll<io::Result<()>> {
75        let (io, res, buf) = match self {
76            State::Configuring(io, buf) => {
77                initialize!(self, io, buf);
78                return Poll::Ready(Ok(()));
79            }
80            State::Idle(_) => {
81                return Poll::Ready(Ok(()));
82            }
83            State::Writing(fut) => {
84                let (io, BufResult(res, buf)) = ready!(fut.poll_unpin(cx));
85                (io, res, buf)
86            }
87            State::Closing(fut) | State::Flushing(fut) => ready!(fut.poll_unpin(cx)),
88        };
89        *self = State::Idle(Some((io, buf)));
90        Poll::Ready(res)
91    }
92}
93
94impl<Io: AsyncWrite + 'static, B: IoBufMut> State<Io, B> {
95    fn start_flush(&mut self) {
96        let (mut io, buf) = self.take_idle();
97        let fut = Box::pin(async move {
98            let res = io.flush().await;
99            (io, res, buf)
100        });
101        *self = State::Flushing(fut);
102    }
103
104    fn start_close(&mut self) {
105        let (mut io, buf) = self.take_idle();
106        let fut = Box::pin(async move {
107            let res = io.shutdown().await;
108            (io, res, buf)
109        });
110        *self = State::Closing(fut);
111    }
112
113    fn start_write(&mut self) {
114        let (mut io, buf) = self.take_idle();
115        let fut = Box::pin(async move {
116            let res = io.write_all(buf).await;
117            (io, res)
118        });
119        *self = State::Writing(fut);
120    }
121}
122
123impl<R, W, C, F, In, Out, B> Sink<In> for Framed<R, W, C, F, In, Out, B>
124where
125    W: AsyncWrite + 'static,
126    C: Encoder<In, B>,
127    F: Framer<B>,
128    B: IoBufMut,
129    Self: Unpin,
130{
131    type Error = C::Error;
132
133    fn poll_ready(
134        self: std::pin::Pin<&mut Self>,
135        cx: &mut std::task::Context<'_>,
136    ) -> std::task::Poll<Result<(), Self::Error>> {
137        let this = self.get_mut();
138        match &mut this.write_state {
139            State::Idle(..) => Poll::Ready(Ok(())),
140            state => state.poll_sink(cx).map_err(C::Error::from),
141        }
142    }
143
144    fn start_send(self: std::pin::Pin<&mut Self>, item: In) -> Result<(), Self::Error> {
145        let this = self.get_mut();
146
147        let buf = this.write_state.buf().expect("`Framed` not in idle state");
148        buf.clear();
149        let _ = buf.reserve(64);
150        if let Err(e) = this.codec.encode(item, buf) {
151            buf.clear();
152            return Err(e);
153        };
154        this.framer.enclose(buf);
155        this.write_state.start_write();
156
157        Ok(())
158    }
159
160    fn poll_flush(
161        mut self: std::pin::Pin<&mut Self>,
162        cx: &mut std::task::Context<'_>,
163    ) -> std::task::Poll<Result<(), Self::Error>> {
164        let this = self.as_mut().get_mut();
165        match &mut this.write_state {
166            State::Configuring(io, buf) => {
167                initialize!(&mut this.write_state, io, buf);
168                self.poll_flush(cx)
169            }
170            State::Idle(_) => {
171                this.write_state.start_flush();
172                this.write_state.poll_sink(cx).map_err(C::Error::from)
173            }
174            State::Writing(_) | State::Flushing(_) => {
175                this.write_state.poll_sink(cx).map_err(C::Error::from)
176            }
177            State::Closing(_) => unreachable!("`Framed` is closing, cannot flush"),
178        }
179    }
180
181    fn poll_close(
182        self: std::pin::Pin<&mut Self>,
183        cx: &mut std::task::Context<'_>,
184    ) -> std::task::Poll<Result<(), Self::Error>> {
185        let this = self.get_mut();
186        match &mut this.write_state {
187            state @ State::Idle(_) => {
188                state.start_close();
189                state.poll_sink(cx).map_err(C::Error::from)
190            }
191            _ => this.write_state.poll_sink(cx).map_err(C::Error::from),
192        }
193    }
194}